
ChannelShuffle layer


ChannelShuffle class

keras_cv.layers.ChannelShuffle(groups=3, seed=None, **kwargs)

Randomly shuffles the channels of an input image.
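Conceptually, shuffling channels reorders an image along its last axis. Every pixel value is kept, and only the channel it belongs to changes, so a pure red pixel may come out green or blue. The NumPy sketch below illustrates that idea and is not the KerasCV implementation. The helper name shuffle_channels is hypothetical, and it permutes each channel individually, which corresponds to every group holding a single channel.

```python
import numpy as np

def shuffle_channels(image, rng):
    """Hypothetical helper: randomly permute the last (channel) axis."""
    perm = rng.permutation(image.shape[-1])  # random ordering of 0..C-1
    return image[..., perm]

rng = np.random.default_rng(0)
# Tiny 2x2 RGB image whose top-left pixel is pure red.
image = np.zeros((2, 2, 3), dtype=np.uint8)
image[0, 0] = [255, 0, 0]

shuffled = shuffle_channels(image, rng)
print(shuffled.shape)  # (2, 2, 3)
print(shuffled[0, 0])  # the 255 lands in one of the three channels
```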

Input shape

3D (unbatched) or 4D (batched) tensor with shape: (..., height, width, channels), in "channels_last" format

Output shape

3D (unbatched) or 4D (batched) tensor with shape: (..., height, width, channels), in "channels_last" format
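Both the input and the output use the "channels_last" layout. Data stored channels-first, for example a batch of shape (batch, channels, height, width), would need its channel axis moved to the end before it reaches the layer, and moved back afterwards if the rest of the pipeline expects channels-first. A minimal NumPy sketch of that conversion, using all-zeros placeholder arrays purely for illustration:

```python
import numpy as np

# Placeholder channels_first batch: (batch, channels, height, width)
batch_cf = np.zeros((8, 3, 32, 32), dtype=np.uint8)

# Move the channel axis to the end -> (batch, height, width, channels)
batch_cl = np.moveaxis(batch_cf, 1, -1)
print(batch_cl.shape)  # (8, 32, 32, 3)

# ... apply the channels_last augmentation here ...

# Move it back to position 1 if downstream code expects channels_first
restored = np.moveaxis(batch_cl, -1, 1)
print(restored.shape)  # (8, 3, 32, 32)

# An unbatched channels_first image (channels, height, width) uses axis 0
image_cl = np.moveaxis(np.zeros((3, 32, 32)), 0, -1)
print(image_cl.shape)  # (32, 32, 3)
```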

Arguments

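  • groups: Integer. Number of groups to divide the input channels into; the order of these groups is randomly permuted. The number of input channels must be divisible by groups. Defaults to 3.
  • seed: Integer, optional. Seed for the layer's random number generator. Defaults to None.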
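To illustrate what groups controls, the sketch below assumes that the channels are split into groups contiguous blocks of equal size, that the blocks are reordered with an independent random permutation per image, and that a channel count not divisible by groups is rejected. This is a hedged reading of the behaviour, not a copy of the KerasCV code. With the default groups=3 on an RGB image, each block is a single channel, so R, G and B are permuted. The function name grouped_channel_shuffle is hypothetical, and NumPy's default_rng(seed) stands in for the layer's own seeding.

```python
import numpy as np

def grouped_channel_shuffle(images, groups=3, seed=None):
    """Hypothetical NumPy sketch of a grouped channel shuffle (channels_last)."""
    unbatched = images.ndim == 3
    if unbatched:
        images = images[np.newaxis]  # add a batch axis: (1, H, W, C)
    batch, height, width, channels = images.shape
    if channels % groups != 0:
        raise ValueError(
            f"channels={channels} is not divisible by groups={groups}"
        )
    rng = np.random.default_rng(seed)
    # (B, H, W, groups, channels_per_group): each group is a contiguous block
    blocks = images.reshape(batch, height, width, groups, channels // groups)
    out = np.empty_like(blocks)
    for i in range(batch):
        perm = rng.permutation(groups)  # a fresh group order for each image
        out[i] = blocks[i][:, :, perm, :]
    out = out.reshape(batch, height, width, channels)
    return out[0] if unbatched else out

# 6 channels with groups=3: pairs (0, 1), (2, 3), (4, 5) move together.
x = np.arange(2 * 4 * 4 * 6).reshape(2, 4, 4, 6)
y = grouped_channel_shuffle(x, groups=3, seed=42)
print(y.shape)  # (2, 4, 4, 6)
```

Permuting whole blocks rather than single channels keeps channels within a group adjacent and in their original relative order, which is the distinction the groups argument draws.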

Usage:

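import keras
import keras_cv
(images, labels), _ = keras.datasets.cifar10.load_data()
# With 3 RGB channels and groups=3, each group is one channel, so R, G, B are shuffled.
channel_shuffle = keras_cv.layers.ChannelShuffle(groups=3)
augmented_images = channel_shuffle(images)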