Einops ops

rearrange function

keras.ops.rearrange(tensor, pattern, **axes_lengths)

Rearranges the axes of a Keras tensor according to a specified pattern, einops-style.

Arguments

  • tensor: Input Keras tensor.
  • pattern: String describing the rearrangement in einops notation.
  • **axes_lengths: Keyword arguments giving the lengths of named axes, required when the pattern decomposes an axis and a length cannot be inferred (for example, h1=2).
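
To illustrate what axes_lengths is for, the sketch below infers the one unknown length in a decomposed group such as (h1 h). The helper infer_unknown_length is hypothetical and written only for this illustration; it is not part of the Keras API and may differ from how Keras resolves lengths internally.

```python
from math import prod


def infer_unknown_length(total, known_lengths):
    """Return the single missing length in a decomposed axis group.

    Hypothetical helper for illustration only; not a Keras function.
    """
    known = prod(known_lengths)
    if known == 0 or total % known != 0:
        raise ValueError(
            f"Axis of length {total} cannot be split by lengths {known_lengths}"
        )
    return total // known


# Pattern 'b (h1 h) (w1 w) c' on a (32, 30, 40, 3) input with h1=2, w1=2:
h = infer_unknown_length(30, [2])
w = infer_unknown_length(40, [2])
print(h, w)  # 15 20
```

Only one length per group can be left unspecified; the rest must be passed as keyword arguments.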

Returns

  • Tensor: A Keras tensor with rearranged axes.

The rearrangement proceeds in three steps:

  1. If decomposition is needed, reshape to match decomposed dimensions.
  2. Permute known and inferred axes to match the form of the output.
  3. Reshape to match the desired output shape.
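
The three steps above can be reproduced by hand with NumPy. This is a minimal sketch, assuming the pattern 'b (h1 h) (w1 w) c -> (b h1 w1) h w c' with h1=2 and w1=2; it mirrors the described logic and is not the Keras implementation itself.

```python
import numpy as np

# Same BHWC shape as the documented example; arange keeps values traceable.
images = np.arange(32 * 30 * 40 * 3).reshape(32, 30, 40, 3)
h1, w1 = 2, 2
b, H, W, c = images.shape
h, w = H // h1, W // w1  # inferred lengths: 15 and 20

# Step 1: reshape so every decomposed axis is explicit: (b, h1, h, w1, w, c).
step1 = images.reshape(b, h1, h, w1, w, c)

# Step 2: permute the axes into output order: (b, h1, w1, h, w, c).
step2 = step1.transpose(0, 1, 3, 2, 4, 5)

# Step 3: merge the grouped axes (b h1 w1) into a single axis.
patches = step2.reshape(b * h1 * w1, h, w, c)

print(patches.shape)  # (128, 15, 20, 3)
```

Each image yields four consecutive patches: top-left, top-right, bottom-left, bottom-right.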

Example Usage:

```python
>>> import numpy as np
>>> from keras.ops import rearrange
>>> images = np.random.rand(32, 30, 40, 3)  # BHWC format

>>> # Shapes below print as TensorShape under the TensorFlow backend.

>>> # Reorder to BCHW
>>> rearrange(images, 'b h w c -> b c h w').shape
TensorShape([32, 3, 30, 40])

>>> # "Merge" along the first axis: stack the images of a batch vertically
>>> rearrange(images, 'b h w c -> (b h) w c').shape
TensorShape([960, 40, 3])

>>> # "Merge" along the second axis: concatenate the images horizontally
>>> rearrange(images, 'b h w c -> h (b w) c').shape
TensorShape([30, 1280, 3])

>>> # Flatten each image into a CHW vector
>>> rearrange(images, 'b h w c -> b (c h w)').shape
TensorShape([32, 3600])

>>> # Split each image into 4 smaller patches by decomposing the H and W axes
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
TensorShape([128, 15, 20, 3])

>>> # Space-to-depth decomposition of the input axes
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
TensorShape([32, 15, 20, 12])

```
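
The two "merge" patterns can be cross-checked against plain NumPy concatenation. The sketch below assumes only NumPy and re-creates the merges with reshape and transpose rather than calling keras.ops.rearrange, so it shows the expected layout of the result, not Keras's own code path.

```python
import numpy as np

images = np.arange(32 * 30 * 40 * 3).reshape(32, 30, 40, 3)  # BHWC

# 'b h w c -> (b h) w c': b and h are already adjacent, so a reshape suffices.
stacked = images.reshape(32 * 30, 40, 3)

# 'b h w c -> h (b w) c': move h in front of b, then merge b and w.
side_by_side = images.transpose(1, 0, 2, 3).reshape(30, 32 * 40, 3)

# Both results match concatenating the individual images along one axis.
assert np.array_equal(stacked, np.concatenate(list(images), axis=0))
assert np.array_equal(side_by_side, np.concatenate(list(images), axis=1))

print(stacked.shape, side_by_side.shape)  # (960, 40, 3) (30, 1280, 3)
```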