DistributedEmbedding layer

DistributedEmbedding class

keras_rs.layers.DistributedEmbedding(
    feature_configs: Union[
        keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
        tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
        Sequence[
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ]
        ],
        Mapping[
            str,
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ],
        ],
    ],
    table_stacking: Union[str, Sequence[str], Sequence[Sequence[str]]] = "auto",
    **kwargs: Any
)

DistributedEmbedding, a layer for accelerated large embedding lookups.


Note: DistributedEmbedding is in Preview.


Configuration

A DistributedEmbedding layer is configured via a set of keras_rs.layers.FeatureConfig objects, which in turn refer to keras_rs.layers.TableConfig objects.

  • TableConfig defines an embedding table, with parameters such as its vocabulary size and embedding dimension, as well as a combiner for reduction and an optimizer for training.
  • FeatureConfig defines which input features the DistributedEmbedding will handle and which embedding table each feature uses. Multiple features can use the same embedding table.

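import keras_rs

# Placeholder sizes for illustration only; substitute values for your data.
TABLE1_VOCABULARY_SIZE = 1000
TABLE1_EMBEDDING_SIZE = 16
TABLE2_VOCABULARY_SIZE = 2000
TABLE2_EMBEDDING_SIZE = 32
PER_REPLICA_BATCH_SIZE = 8

table1 = keras_rs.layers.TableConfig(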
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
)
table2 = keras_rs.layers.TableConfig(
    name="table2",
    vocabulary_size=TABLE2_VOCABULARY_SIZE,
    embedding_dim=TABLE2_EMBEDDING_SIZE,
)

feature1 = keras_rs.layers.FeatureConfig(
    name="feature1",
    table=table1,
    input_shape=(PER_REPLICA_BATCH_SIZE,),
    output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
)
feature2 = keras_rs.layers.FeatureConfig(
    name="feature2",
    table=table2,
    input_shape=(PER_REPLICA_BATCH_SIZE,),
    output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
)

feature_configs = {
    "feature1": feature1,
    "feature2": feature2,
}

embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
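
As a rough intuition for the note that multiple features can use the same embedding table, the sketch below models a shared table in plain Python. It does not use KerasRS and is not how the layer is implemented; the table contents and the feature names `watched_item` and `candidate_item` are hypothetical. The point is only that an ID resolves to the same vector whichever feature it arrives through.

```python
# Conceptual sketch only: plain Python, not the KerasRS implementation.
# One shared table: row i is the embedding vector for ID i.
shared_table = [
    [0.0, 0.1],
    [1.0, 1.1],
    [2.0, 2.1],
    [3.0, 3.1],
]

# Hypothetical mapping from feature name to the table it reads from.
feature_to_table = {
    "watched_item": shared_table,
    "candidate_item": shared_table,
}


def lookup(feature_name, ids):
    """Return one embedding vector per ID for the given feature."""
    table = feature_to_table[feature_name]
    return [table[i] for i in ids]


watched = lookup("watched_item", [2, 0])
candidate = lookup("candidate_item", [2])
```

Here `watched[0]` and `candidate[0]` are the same row of `shared_table`, which is what sharing a table between two FeatureConfig objects amounts to conceptually.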

Optimizers

Each embedding table within DistributedEmbedding uses its own optimizer for training, which is independent from the optimizer set on the model via model.compile().

Note that not all optimizers are supported, and only some are supported on all backends and accelerators. Likewise, not all optimizer parameters are supported (e.g. the nesterov option of SGD). An error is raised when an unsupported optimizer or an unsupported optimizer parameter is used.
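
To make the idea of per-table optimizers concrete, here is a plain-Python sketch in which each table carries its own learning rate and is updated separately. It is an illustration under simplifying assumptions (plain SGD, hand-written gradients), not the KerasRS implementation; the helper `sgd_update_rows` and the dictionary layout are hypothetical.

```python
# Conceptual sketch only: plain Python, not the KerasRS implementation.


def sgd_update_rows(table, row_grads, learning_rate):
    """Apply plain SGD to the rows that received gradients.

    table: list of embedding vectors (list of lists of floats).
    row_grads: dict mapping row index -> gradient vector for that row.
    """
    for row, grad in row_grads.items():
        table[row] = [w - learning_rate * g for w, g in zip(table[row], grad)]


# Two tables, each paired with its own (hypothetical) learning rate.
tables = {
    "table1": {"weights": [[1.0, 1.0], [2.0, 2.0]], "learning_rate": 0.5},
    "table2": {"weights": [[1.0, 1.0], [2.0, 2.0]], "learning_rate": 0.1},
}

# The same gradient reaches row 0 of both tables; row 1 gets no gradient.
grads = {0: [1.0, 1.0]}
for config in tables.values():
    sgd_update_rows(config["weights"], grads, config["learning_rate"])
```

Both tables see the same gradient but move by different amounts, and neither update depends on any model-level optimizer. Rows that were not looked up are left untouched.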

Arguments

  • feature_configs: A nested structure of keras_rs.layers.FeatureConfig.
  • table_stacking: The table stacking to use. None means no table stacking. "auto" means tables are stacked automatically. A list of table names, or a list of lists of table names, stacks the tables named in each inner list together. Table stacking is not supported on older TPUs, where the default value of "auto" is interpreted as no table stacking. Defaults to "auto".
  • **kwargs: Additional arguments to pass to the layer base class.
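
As a rough intuition for what stacking tables together can mean, the sketch below concatenates two logical tables into one physical table and shifts each table's IDs by a row offset. This is one common way to realize stacking and is an assumption for illustration, not a description of how KerasRS or the underlying hardware does it. The helpers `stack_tables` and `stacked_lookup` are hypothetical, and the tables are assumed to share an embedding dimension.

```python
# Conceptual sketch only: plain Python, not the KerasRS implementation.


def stack_tables(tables):
    """Concatenate tables row-wise and record each table's row offset.

    tables: dict mapping table name -> list of embedding vectors.
    All tables are assumed to share the same embedding dimension.
    """
    stacked, offsets = [], {}
    for name, rows in tables.items():
        offsets[name] = len(stacked)
        stacked.extend(rows)
    return stacked, offsets


def stacked_lookup(stacked, offsets, table_name, ids):
    """Look up IDs of one logical table inside the stacked table."""
    base = offsets[table_name]
    return [stacked[base + i] for i in ids]


tables = {
    "table1": [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    "table2": [[10.0, 10.0], [20.0, 20.0]],
}
stacked, offsets = stack_tables(tables)
result = stacked_lookup(stacked, offsets, "table2", [1, 0])
```

In terms of the argument above, this corresponds to the case where a single inner list names both "table1" and "table2".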

call method

DistributedEmbedding.call(
    inputs: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
    ],
    weights: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
        NoneType,
    ] = None,
    training: bool = False,
)

Look up features in embedding tables and apply reduction.

Arguments

  • inputs: A nested structure of 2D tensors to embed and reduce. The structure must be the same as the feature_configs passed during construction.
  • weights: An optional nested structure of 2D tensors of weights to apply before reduction. When present, the structure must be the same as inputs and the shapes must match.
  • training: Whether the layer is called in training mode. Defaults to False.

Returns

A nested structure of dense 2D tensors, which are the reduced embeddings from the passed features. The structure is the same as inputs.
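
The lookup-and-reduce behavior described above can be sketched for a single feature in plain Python. This is an illustration, not the KerasRS implementation. It assumes combiner names "sum" and "mean", and it assumes that a weighted "mean" divides by the sum of the weights, as TensorFlow's `tf.nn.embedding_lookup_sparse` does; the layer's actual behavior may differ. The helper `embed_and_reduce` is hypothetical.

```python
# Conceptual sketch only: plain Python, not the KerasRS implementation.


def embed_and_reduce(table, ids, weights=None, combiner="mean"):
    """Look up each sample's IDs and reduce them to one vector per sample.

    table: list of embedding vectors.
    ids: 2D list of shape (batch_size, ids_per_sample).
    weights: optional 2D list with the same shape as ids.
    combiner: "sum" or "mean" (assumed names).
    """
    if weights is None:
        weights = [[1.0] * len(sample) for sample in ids]
    dim = len(table[0])
    outputs = []
    for sample_ids, sample_weights in zip(ids, weights):
        # Weighted sum of the looked-up rows for this sample.
        total = [0.0] * dim
        for i, w in zip(sample_ids, sample_weights):
            total = [t + w * v for t, v in zip(total, table[i])]
        if combiner == "mean":
            # Assumption: normalize by the sum of the weights.
            weight_sum = sum(sample_weights)
            total = [t / weight_sum for t in total]
        elif combiner != "sum":
            raise ValueError(f"Unsupported combiner: {combiner}")
        outputs.append(total)
    return outputs


table = [[0.0, 0.0], [2.0, 4.0], [4.0, 8.0]]
ids = [[1, 2], [0, 1]]  # shape (batch_size=2, ids_per_sample=2)
summed = embed_and_reduce(table, ids, combiner="sum")
averaged = embed_and_reduce(table, ids, combiner="mean")
weighted = embed_and_reduce(table, ids, weights=[[3.0, 1.0], [1.0, 1.0]])
```

Each result is a 2D structure of shape (batch_size, embedding_dim): the input has one row of IDs per sample, and the output has one reduced vector per sample. With a nested structure of features, the same operation would apply to each feature independently.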