Getting started with Keras Core

Author: fchollet
Date created: 2023/07/10
Last modified: 2023/07/10
Description: First contact with the new multi-backend Keras.



Introduction

Keras Core is a full implementation of the Keras API that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras Core workflows.

First, let's install Keras Core:

!pip install -q keras-core

Setup

We're going to be using the JAX backend here -- but you can edit the string below to "tensorflow" or "torch" and hit "Restart runtime", and the whole notebook will run just the same! This entire guide is backend-agnostic.

import numpy as np
import os

os.environ["KERAS_BACKEND"] = "jax"

# Note that keras_core should only be imported after the backend
# has been configured. The backend cannot be changed once the
# package is imported.
import keras_core as keras
Using JAX backend.
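If you run this guide as a standalone script instead of a notebook, one option is to set the backend only as a default, so that a value already exported in the shell takes precedence. This sketch uses the standard-library `os.environ.setdefault`. The only Keras-specific detail it relies on is the `KERAS_BACKEND` variable used above, which has to be set before `keras_core` is imported.

```python
import os

# Only provide a default: a KERAS_BACKEND value already exported in the shell
# (e.g. `KERAS_BACKEND=torch python train.py`) is left untouched.
os.environ.setdefault("KERAS_BACKEND", "jax")

backend_name = os.environ["KERAS_BACKEND"]
print("Requested backend:", backend_name)

# `import keras_core as keras` must come after this point, because the
# backend cannot be changed once the package has been imported.
```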

A first example: An MNIST convnet

Let's start with the Hello World of ML: training a convnet to classify MNIST digits.

Here's the data:

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
x_train shape: (60000, 28, 28, 1)
y_train shape: (60000,)
60000 train samples
10000 test samples

Here's our model.

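Different model-building options that Keras offers include:

  • The Sequential API (what we use below)
  • The Functional API (the most common choice)
  • Writing your own models via subclassing (for advanced use cases)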

# Model parameters
num_classes = 10
input_shape = (28, 28, 1)

model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

Here's our model summary:

model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 26, 26, 64)        │        640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 24, 24, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 12, 12, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 10, 10, 128)       │     73,856 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 8, 8, 128)         │    147,584 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_average_pooling2d        │ (None, 128)               │          0 │
│ (GlobalAveragePooling2D)        │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 128)               │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 10)                │      1,290 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 260,298 (7.94 MB)
 Trainable params: 260,298 (7.94 MB)
 Non-trainable params: 0 (0.00 B)
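The shapes and parameter counts in this summary follow from simple arithmetic, which makes a handy sanity check when you change the architecture. The sketch below recomputes them for 3x3 convolutions with Keras's default "valid" padding and stride 1 (each one trims 2 pixels from height and width). The helper name `conv2d_valid` is made up for illustration.

```python
def conv2d_valid(h, w, c_in, filters, k=3):
    """Output shape and parameter count of an unpadded, stride-1 Conv2D."""
    out_h, out_w = h - k + 1, w - k + 1
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return (out_h, out_w, filters), params


shape = (28, 28, 1)
layer_params = []
for filters in (64, 64):
    shape, p = conv2d_valid(*shape, filters)
    layer_params.append(p)

# MaxPooling2D(pool_size=(2, 2)) halves height and width (floor) and has no weights.
shape = (shape[0] // 2, shape[1] // 2, shape[2])
layer_params.append(0)

for filters in (128, 128):
    shape, p = conv2d_valid(*shape, filters)
    layer_params.append(p)

# GlobalAveragePooling2D averages over height and width: (8, 8, 128) -> (128,).
features = shape[2]
# Dense(10): one weight per (input, unit) pair plus one bias per unit.
layer_params.append(features * 10 + 10)

print(shape, layer_params, sum(layer_params))
# (8, 8, 128) [640, 36928, 0, 73856, 147584, 1290] 260298
```

Dropout contributes no weights, so it is left out of the list.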

We use the compile() method to specify the optimizer, loss function, and the metrics to monitor. Note that with the JAX and TensorFlow backends, XLA compilation is turned on by default.

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
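`SparseCategoricalCrossentropy` and `SparseCategoricalAccuracy` take integer class labels such as `y_train` directly, with no one-hot encoding. Here is a rough NumPy sketch of what they compute for probability outputs like the ones our softmax layer produces. The function names and clipping constant are illustrative; Keras's own implementation also handles logits, masking, and other reduction modes.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean negative log-probability assigned to the true integer class.

    y_true: integer labels, shape (batch,).
    y_pred: predicted probabilities, shape (batch, num_classes).
    Probabilities are clipped to [eps, 1 - eps] to avoid log(0).
    """
    probs = np.clip(y_pred, eps, 1.0 - eps)
    true_class_probs = probs[np.arange(len(y_true)), y_true]
    return float(np.mean(-np.log(true_class_probs)))


def sparse_categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-probability class equals the label."""
    return float(np.mean(np.argmax(y_pred, axis=-1) == y_true))


# Toy batch: 2 samples, 3 classes (illustrative values only).
y_true = np.array([0, 2])
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.5, 0.3, 0.2]])

loss = sparse_categorical_crossentropy(y_true, y_pred)
acc = sparse_categorical_accuracy(y_true, y_pred)
print(f"loss={loss:.4f} acc={acc:.2f}")  # loss=0.9831 acc=0.50
```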

Let's train and evaluate the model. We'll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.

batch_size = 128
epochs = 20

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath="model_at_epoch_{epoch}.keras"),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2),
]

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,
    callbacks=callbacks,
)
score = model.evaluate(x_test, y_test, verbose=0)
Epoch 1/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 172ms/step - acc: 0.5219 - loss: 1.3333 - val_acc: 0.9581 - val_loss: 0.1400
Epoch 2/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 172ms/step - acc: 0.9260 - loss: 0.2497 - val_acc: 0.9743 - val_loss: 0.0881
Epoch 3/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 171ms/step - acc: 0.9505 - loss: 0.1634 - val_acc: 0.9839 - val_loss: 0.0567
Epoch 4/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 170ms/step - acc: 0.9611 - loss: 0.1314 - val_acc: 0.9798 - val_loss: 0.0673
Epoch 5/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 171ms/step - acc: 0.9696 - loss: 0.1059 - val_acc: 0.9873 - val_loss: 0.0459
Epoch 6/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 174ms/step - acc: 0.9729 - loss: 0.0920 - val_acc: 0.9891 - val_loss: 0.0400
Epoch 7/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 176ms/step - acc: 0.9755 - loss: 0.0828 - val_acc: 0.9890 - val_loss: 0.0360
Epoch 8/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 171ms/step - acc: 0.9775 - loss: 0.0761 - val_acc: 0.9904 - val_loss: 0.0359
Epoch 9/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 174ms/step - acc: 0.9795 - loss: 0.0663 - val_acc: 0.9903 - val_loss: 0.0342
Epoch 10/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 170ms/step - acc: 0.9813 - loss: 0.0621 - val_acc: 0.9904 - val_loss: 0.0360
Epoch 11/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 170ms/step - acc: 0.9817 - loss: 0.0575 - val_acc: 0.9910 - val_loss: 0.0285
Epoch 12/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 67s 169ms/step - acc: 0.9845 - loss: 0.0525 - val_acc: 0.9908 - val_loss: 0.0346
Epoch 13/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 67s 168ms/step - acc: 0.9843 - loss: 0.0514 - val_acc: 0.9917 - val_loss: 0.0283
Epoch 14/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 67s 168ms/step - acc: 0.9859 - loss: 0.0481 - val_acc: 0.9924 - val_loss: 0.0242
Epoch 15/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 169ms/step - acc: 0.9863 - loss: 0.0466 - val_acc: 0.9920 - val_loss: 0.0269
Epoch 16/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 170ms/step - acc: 0.9867 - loss: 0.0445 - val_acc: 0.9936 - val_loss: 0.0228
Epoch 17/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 174ms/step - acc: 0.9875 - loss: 0.0423 - val_acc: 0.9926 - val_loss: 0.0261
Epoch 18/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 174ms/step - acc: 0.9875 - loss: 0.0408 - val_acc: 0.9934 - val_loss: 0.0231
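Several numbers in this log can be reproduced by hand. With `validation_split=0.15`, Keras holds out the last 15% of the arrays (9,000 samples), and the remaining 51,000 samples at `batch_size=128` give 399 steps per epoch, the last one partial. The sketch below also replays a simplified version of the `EarlyStopping` rule (`min_delta=0`, no baseline) on the `val_loss` values copied from the log, and lists the checkpoint filenames, assuming `ModelCheckpoint` fills `{epoch}` with the 1-based epoch number. The helper `replay_early_stopping` is hypothetical, not a Keras function.

```python
import math

num_samples = 60000        # len(x_train)
validation_split = 0.15
batch_size = 128

# validation_split holds out the *last* 15% of the arrays, before any shuffling.
num_val = round(num_samples * validation_split)    # 9000 samples
num_fit = num_samples - num_val                    # 51000 samples
steps_per_epoch = math.ceil(num_fit / batch_size)  # final batch is partial

# val_loss per epoch, copied from the training log above (epochs 1-18).
val_losses = [0.1400, 0.0881, 0.0567, 0.0673, 0.0459, 0.0400, 0.0360, 0.0359,
              0.0342, 0.0360, 0.0285, 0.0346, 0.0283, 0.0242, 0.0269, 0.0228,
              0.0261, 0.0231]


def replay_early_stopping(losses, patience):
    """Return the 1-based epoch at which training stops, or None if it never does.

    Simplified rule: reset a counter whenever the loss reaches a new minimum,
    otherwise increment it, and stop once it reaches `patience`.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


stopped_at = replay_early_stopping(val_losses, patience=2)

# One checkpoint per completed epoch, named with the 1-based epoch number.
checkpoints = [f"model_at_epoch_{e}.keras" for e in range(1, stopped_at + 1)]

print(num_val, steps_per_epoch, stopped_at)  # 9000 399 18
print(checkpoints[0], "...", checkpoints[-1])
```

Training stops at epoch 18 because `val_loss` failed to improve on epoch 16's 0.0228 for two consecutive epochs. Since `EarlyStopping` defaults to `restore_best_weights=False`, the model left in memory at this point carries the epoch-18 weights rather than the best epoch-16 ones; pass `restore_best_weights=True` if you want the latter.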

During training, we were saving a model at the end of each epoch. You can also save the model in its latest state like this:

model.save("final_model.keras")

And reload it like this:

model = keras.saving.load_model("final_model.keras")

Next, you can query predictions of class probabilities with predict():

predictions = model.predict(x_test)
 313/313 ━━━━━━━━━━━━━━━━━━━━ 3s 10ms/step
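`predict()` returns one row of class probabilities per sample, so `predictions` here has shape `(10000, 10)`, and taking the `argmax` of each row gives the predicted digit. The sketch below uses a small random stand-in array rather than real model output. It also shows where the 313 steps come from: `predict()` uses a default batch size of 32.

```python
import math

import numpy as np

# Illustrative stand-in for model output: 3 samples, 10 classes (not real predictions).
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 10))
predictions = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax rows

predicted_labels = np.argmax(predictions, axis=1)  # most likely digit per sample
confidence = predictions.max(axis=1)               # probability of that digit

# predict() defaults to batch_size=32, hence ceil(10000 / 32) = 313 steps in the log.
predict_steps = math.ceil(10000 / 32)
print(predicted_labels, confidence.round(3), predict_steps)
```

With the real arrays, `np.mean(np.argmax(predictions, axis=1) == y_test)` gives the test accuracy, which should closely match `score[1]` from the earlier `model.evaluate()` call.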

That's it for the basics!


Writing cross-framework custom components

Keras Core enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look at custom layers first.

If you're already familiar with writing custom layers in tf.keras -- well, nothing has changed. Except one thing: instead of using functions from the tf namespace, you should use functions from keras.ops.*.

The keras.ops namespace contains:

  • An implementation of the NumPy API, e.g. keras.ops.stack or keras.ops.matmul.
  • A set of neural network specific ops that are absent from NumPy, such as keras.ops.conv or keras.ops.binary_crossentropy.

Let's make a custom Dense layer that works with all backends:

class MyDense(keras.layers.Layer):
    def __init__(self, units, activation=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=keras.initializers.GlorotNormal(),
            name="kernel",
            trainable=True,
        )

        self.b = self.add_weight(
            shape=(self.units,),
            initializer=keras.initializers.Zeros(),
            name="bias",
            trainable=True,
        )

    def call(self, inputs):
        # Use Keras ops to create backend-agnostic layers/metrics/etc.
        x = keras.ops.matmul(inputs, self.w) + self.b
        return self.activation(x)
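Because `keras.ops` implements the NumPy API, the forward pass of `MyDense` can be cross-checked against plain NumPy. A minimal sketch, assuming a 128-dimensional input and 10 units with a softmax activation, and using a plain normal initializer as a stand-in: Keras's `GlorotNormal` uses the same standard deviation, sqrt(2 / (fan_in + fan_out)), but draws from a truncated normal. All function names below are illustrative.

```python
import numpy as np


def glorot_normal_stand_in(fan_in, fan_out, rng):
    """Plain normal with Glorot's stddev; Keras's GlorotNormal truncates instead."""
    return rng.normal(0.0, np.sqrt(2.0 / (fan_in + fan_out)), size=(fan_in, fan_out))


def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def my_dense_forward(inputs, w, b, activation=None):
    """Same math as MyDense.call: activation(inputs @ w + b)."""
    x = np.matmul(inputs, w) + b  # (batch, input_dim) @ (input_dim, units) + (units,)
    return x if activation is None else activation(x)


rng = np.random.default_rng(0)
input_dim, units = 128, 10
w = glorot_normal_stand_in(input_dim, units, rng)
b = np.zeros(units)  # matches the Zeros() bias initializer
inputs = rng.normal(size=(4, input_dim))

outputs = my_dense_forward(inputs, w, b, activation=softmax)
print(outputs.shape, outputs.sum(axis=1))  # (4, 10), each row sums to ~1
```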

Next, let's make a custom Dropout layer that relies on the keras.random namespace:

class MyDropout(keras.layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs, training=False):
        # Dropout should only be active during training; at inference
        # (evaluate/predict) the inputs pass through unchanged.
        if not training:
            return inputs
        # Use `keras_core.random` for random ops.
        return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

Next, let's write a custom subclassed model that uses our two custom layers:

class MyModel(keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv_base = keras.Sequential(
            [
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.MaxPooling2D(pool_size=(2, 2)),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.GlobalAveragePooling2D(),
            ]
        )
        self.dp = MyDropout(0.5)
        self.dense = MyDense(num_classes, activation="softmax")

    def call(self, x, training=False):
        x = self.conv_base(x)
        # Forward the training flag so MyDropout knows whether to drop units.
        x = self.dp(x, training=training)
        return self.dense(x)
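`keras.random.dropout`, which `MyDropout` calls, follows the usual "inverted dropout" convention (as `tf.nn.dropout` and `torch.nn.functional.dropout` do): each entry is zeroed with probability `rate`, and the survivors are scaled by `1 / (1 - rate)` so the expected activation is unchanged. That scaling is why dropout can simply be switched off at inference. Here is a NumPy sketch of the idea; the function name is made up, and NumPy's RNG will not reproduce the random stream of Keras's `SeedGenerator`.

```python
import numpy as np


def inverted_dropout(inputs, rate, rng):
    """Zero each entry with probability `rate`; scale survivors by 1 / (1 - rate).

    The rescaling keeps the expected value of every entry unchanged, so no
    correction is needed at inference time, when dropout is skipped.
    """
    keep_prob = 1.0 - rate
    mask = rng.random(inputs.shape) < keep_prob  # True with probability keep_prob
    return np.where(mask, inputs / keep_prob, 0.0)


rng = np.random.default_rng(1337)  # NumPy RNG, unrelated to keras.random.SeedGenerator
x = np.ones((1000, 128))
y = inverted_dropout(x, rate=0.5, rng=rng)

print("fraction zeroed:", np.mean(y == 0.0))  # close to 0.5
print("surviving value:", y[y != 0.0][0])     # 1 / (1 - 0.5) = 2.0
print("mean activation:", y.mean())           # close to 1.0, the input mean
```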

Let's compile it and fit it:

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=1,  # For speed
    validation_split=0.15,
)
 399/399 ━━━━━━━━━━━━━━━━━━━━ 71s 176ms/step - acc: 0.5410 - loss: 1.2896 - val_acc: 0.9216 - val_loss: 0.2615

<keras_core.src.callbacks.history.History at 0x28a8047c0>

Training models on arbitrary data sources

All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you're using. This includes:

  • NumPy arrays
  • Pandas dataframes
  • TensorFlow tf.data.Dataset objects
  • PyTorch DataLoader objects
  • Keras PyDataset objects

They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.

Let's try it out with PyTorch DataLoaders:

import torch

# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)

# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=batch_size, shuffle=False
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)
 469/469 ━━━━━━━━━━━━━━━━━━━━ 82s 173ms/step - acc: 0.5713 - loss: 1.2036 - val_acc: 0.9319 - val_loss: 0.2262

<keras_core.src.callbacks.history.History at 0x2e41dac20>
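The step count went from 399 to 469 because this time the model trains on all 60,000 training samples (validation runs on the separate 10,000-sample test loader instead of a 15% split), and a `DataLoader` with its default `drop_last=False` keeps the final partial batch. A small sketch of that batching arithmetic, using a hypothetical helper `batch_sizes`:

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Batch sizes a sequential batcher yields over one pass (DataLoader-style)."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # final partial batch
    return sizes


train_batches = batch_sizes(60000, 128)  # all of x_train; no validation_split here
val_batches = batch_sizes(10000, 128)    # x_test used as validation data
print(len(train_batches), train_batches[-1], len(val_batches))  # 469 96 79
```

The same count applies to the `tf.data` pipeline that follows, since `Dataset.batch()` also keeps the remainder by default (`drop_remainder=False`).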

Now let's try this out with tf.data:

import tensorflow as tf

train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)
 469/469 ━━━━━━━━━━━━━━━━━━━━ 80s 170ms/step - acc: 0.5450 - loss: 1.2778 - val_acc: 0.8986 - val_loss: 0.3166

<keras_core.src.callbacks.history.History at 0x2e467fdc0>

Further reading

This concludes our short overview of the new multi-backend capabilities of Keras Core. Next, you can learn about:

How to customize what happens in fit()

Want to implement a non-standard training algorithm yourself (e.g. a GAN training routine) but still want to benefit from the power and usability of fit()? It's really easy to customize fit() to support arbitrary use cases.


How to write custom training loops


How to distribute training

Enjoy the library! 🚀