RandomGrayscale layer

RandomGrayscale class

```python
keras.layers.RandomGrayscale(factor=0.5, data_format=None, seed=None, **kwargs)
```

Preprocessing layer for random conversion of RGB images to grayscale.

This layer randomly converts input images to grayscale, selecting each
image independently with probability `factor`. Converted images keep
their original number of channels, with every channel set to the same
grayscale value. This makes the layer useful for data augmentation and
for training models to be robust to color variations.
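
The per-image selection and conversion described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the Keras implementation: it assumes `channels_last` float input with 3 channels and ITU-R BT.601 luma weights, and the function name `random_grayscale_sketch` is hypothetical.

```python
import numpy as np

# Assumed luma weights (ITU-R BT.601); the exact coefficients Keras uses
# are an assumption of this sketch.
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def random_grayscale_sketch(images, factor=0.5, seed=None):
    """Hypothetical re-implementation of the documented behavior.

    images: array of shape (batch, height, width, 3), channels_last.
    Each image is replaced, with probability `factor`, by its grayscale
    version replicated across all three channels.
    """
    images = np.asarray(images, dtype=np.float32)
    rng = np.random.default_rng(seed)
    # One independent draw per image decides whether it is converted.
    convert = rng.random(images.shape[0]) < factor
    # Weighted sum over the channel axis: (b, h, w, 3) @ (3,) -> (b, h, w),
    # then restore a size-1 channel axis.
    gray = (images @ LUMA_WEIGHTS)[..., np.newaxis]
    # Broadcasting the (b, 1, 1, 1) mask against (b, h, w, 1) and (b, h, w, 3)
    # yields (b, h, w, 3): converted images get the gray value in every channel.
    return np.where(convert[:, None, None, None], gray, images)
```

Using `np.where` with a per-image mask lets converted and unconverted images share one output array with the input's shape.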

The conversion preserves the perceived luminance of the original color
image by using standard RGB-to-grayscale conversion coefficients. Images
that are not selected for conversion are returned unchanged, as are all
images when the layer is called with `training=False`.
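
A worked form of the luminance conversion, assuming the ITU-R BT.601 luma weights (the text above only says "standard" coefficients, so the exact values the layer uses are an assumption here):

$$
Y = 0.299\,R + 0.587\,G + 0.114\,B
$$

Because the weights sum to 1, a pixel that is already gray ($R = G = B = v$) keeps the value $v$. For example, the pixel $(R, G, B) = (200, 100, 50)$ maps to

$$
Y = 0.299 \cdot 200 + 0.587 \cdot 100 + 0.114 \cdot 50 = 59.8 + 58.7 + 5.7 = 124.2,
$$

so a converted image would store $(124.2, 124.2, 124.2)$ at that position.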

**Note:** This layer is safe to use inside a [`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) or `grain` pipeline
(independently of which backend you're using).

# Arguments
    factor: Float between 0 and 1, specifying the probability that each
        image is converted to grayscale. Defaults to 0.5. A value of
        1.0 means all images will be converted, while 0.0 means no images
        will be converted.
    data_format: String, one of `"channels_last"` (default) or
        `"channels_first"`. The ordering of the dimensions in the inputs.
        `"channels_last"` corresponds to inputs with shape
        `(batch, height, width, channels)` while `"channels_first"`
        corresponds to inputs with shape
        `(batch, channels, height, width)`.
    seed: Integer. Used to create a random seed.
    **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
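
To make the meaning of `factor` and `data_format` concrete, here is a small NumPy sketch. The helpers `selection_mask` and `channel_axis` are hypothetical and only model the semantics described in the arguments above; they are not part of the Keras API, and how the layer's own `seed` sequences its draws across calls is not covered here.

```python
import numpy as np


def selection_mask(batch_size, factor, seed=None):
    """Hypothetical helper: True marks the images that would be converted.

    Models `factor` as an independent per-image probability; this is not
    the Keras implementation.
    """
    if not 0.0 <= factor <= 1.0:
        raise ValueError(f"factor must be between 0 and 1, got {factor}")
    rng = np.random.default_rng(seed)
    # Draws lie in [0, 1), so `draw < factor` is never true for 0.0,
    # always true for 1.0, and true with probability `factor` otherwise.
    return rng.random(batch_size) < factor


def channel_axis(data_format="channels_last"):
    """Index of the channel axis for a batched 4D tensor."""
    if data_format == "channels_last":
        return 3  # (batch, height, width, channels)
    if data_format == "channels_first":
        return 1  # (batch, channels, height, width)
    raise ValueError(f"Unknown data_format: {data_format!r}")
```

With a fixed `seed`, the helper draws the same mask every time, which makes the selection easy to inspect in tests.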

# Input shape
    3D (unbatched) or 4D (batched) tensor with shape:
    `(..., height, width, channels)`, in `"channels_last"` format,
    or `(..., channels, height, width)`, in `"channels_first"` format.
    The channel dimension should hold the 3 RGB channels.
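
One way to handle both accepted ranks and both layouts is to map every input onto a single `(batch, height, width, channels)` form and map the result back afterwards. The sketch below shows that idea with hypothetical helpers `to_batched_channels_last` and `from_batched_channels_last`; it is an illustration only and does not claim to describe how Keras handles shapes internally.

```python
import numpy as np


def to_batched_channels_last(x, data_format="channels_last"):
    """Hypothetical helper: bring a 3D or 4D image tensor to
    (batch, height, width, channels) and report whether it was unbatched.
    """
    x = np.asarray(x)
    if x.ndim not in (3, 4):
        raise ValueError(f"Expected a 3D or 4D tensor, got ndim={x.ndim}")
    unbatched = x.ndim == 3
    if unbatched:
        x = x[np.newaxis]  # add a batch axis of size 1
    if data_format == "channels_first":
        x = np.transpose(x, (0, 2, 3, 1))  # (b, c, h, w) -> (b, h, w, c)
    elif data_format != "channels_last":
        raise ValueError(f"Unknown data_format: {data_format!r}")
    return x, unbatched


def from_batched_channels_last(x, unbatched, data_format="channels_last"):
    """Hypothetical inverse of to_batched_channels_last."""
    if data_format == "channels_first":
        x = np.transpose(x, (0, 3, 1, 2))  # (b, h, w, c) -> (b, c, h, w)
    return x[0] if unbatched else x
```

Round-tripping through both helpers returns an array with the original shape and layout, which matches the "same as input shape" contract below.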

# Output shape
    Same as input shape. The output maintains the same number of channels
    as the input, even for grayscale-converted images where all channels
    will have the same value.
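
Because converted images keep all their channels with identical values, they can be spotted in an output batch by checking whether every channel equals the first one. The helper `converted_images` below is hypothetical. Note that an input image that was already gray also passes this check, so it flags images that look converted rather than proving the layer converted them.

```python
import numpy as np


def converted_images(batch, data_format="channels_last", atol=1e-6):
    """Hypothetical helper: one boolean per image, True when every channel
    holds the same value (i.e. the image looks grayscale-converted).
    """
    batch = np.asarray(batch, dtype=np.float64)
    axis = -1 if data_format == "channels_last" else 1
    # Keep the channel axis (size 1) so it broadcasts against all channels.
    first = np.take(batch, [0], axis=axis)
    same = np.isclose(batch, first, atol=atol)
    # An image counts as gray only if all of its positions agree.
    return same.reshape(batch.shape[0], -1).all(axis=1)
```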

# Example
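```python
import keras
import numpy as np

# Convert each image to grayscale with probability 0.5 (the default factor).
# RandomGrayscale has no `value_range` argument.
layer = keras.layers.RandomGrayscale(factor=0.5)

# A batch of 8 random RGB images with pixel values in [0, 255].
images = np.random.randint(0, 256, (8, 224, 224, 3), dtype="uint8")

# One-hot labels for 3 classes.
labels = keras.ops.one_hot(
    np.array([0, 1, 2, 0, 1, 2, 0, 1]),
    num_classes=3,
)

# Per-pixel class masks with a single channel.
segmentation_masks = np.random.randint(0, 3, (8, 224, 224, 1), dtype="uint8")

# Pass the data as a dict; training=True enables the random augmentation.
output = layer(
    {
        "images": images,
        "labels": labels,
        "segmentation_masks": segmentation_masks,
    },
    training=True,
)
```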