SeedGenerator
class keras.random.SeedGenerator(seed=None, name=None, **kwargs)
Generates a new seed on each call to a function that generates random numbers.
In Keras, all random number generators (such as keras.random.normal())
are stateless, meaning that if you pass an integer seed to them (such as
seed=42), they will return the same values for repeated calls. To get
different values for each call, a SeedGenerator providing the state of
the random generator has to be used.
Note that all the random number generators have a default seed of None,
which implies that an internal global SeedGenerator is used.
If you need to decouple the RNG from the global state, you can provide
a local SeedGenerator with either a deterministic or random initial
state.
Remark concerning the JAX backend: Note that the use of a local
SeedGenerator as the seed argument is required for JIT compilation of
RNG with the JAX backend, because the use of global state is not
supported.
Example
seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)
Usage in a layer:
class Dropout(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            return keras.random.dropout(
                x, rate=0.5, seed=self.seed_generator
            )
        return x