
SeedGenerator class



keras.random.SeedGenerator(seed=None, name=None, **kwargs)

Generates a new seed each time it is passed to a function that draws random numbers.

In Keras, all random number generators (such as keras.random.normal()) are stateless: if you pass them an integer seed (such as seed=42), repeated calls return the same values. To get different values on each call, pass a SeedGenerator instead; it holds the state of the random generator and advances it on every call.

Note that all random number generators default to seed=None, which means that an internal global SeedGenerator is used. To decouple the RNG from this global state, pass a local SeedGenerator, created with either a deterministic or a random initial state.

Remark concerning the JAX backend: JIT compilation of RNG calls with the JAX backend requires passing a local SeedGenerator as the seed argument, because using the global state is not supported.

Example

import keras

seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
# seed_gen advanced its state on the previous call, so these values differ.
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)

Usage in a layer:

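import keras

class Dropout(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Create the generator once so every call draws from the same advancing state.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            # A new dropout mask is drawn on each training call.
            return keras.random.dropout(
                x, rate=0.5, seed=self.seed_generator
            )
        # Inference: pass the inputs through unchanged.
        return x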