After five months of extensive public beta testing, we're excited to announce the official release of Keras 3.0. Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, or PyTorch, and that unlocks brand new large-scale model training and deployment capabilities. You can pick the framework that suits you best, and switch from one to another based on your current goals. You can also use Keras as a low-level cross-framework language to develop custom components such as layers, models, or metrics that can be used in native workflows in JAX, TensorFlow, or PyTorch — with one codebase.
Welcome to multi-framework machine learning.
You're already familiar with the benefits of using Keras — it enables high-velocity development via an obsessive focus on great UX, API design, and debuggability. It's also a battle-tested framework that has been chosen by over 2.5M developers and that powers some of the most sophisticated, largest-scale ML systems in the world, such as the Waymo self-driving fleet and the YouTube recommendation engine. But what are the additional benefits of using the new multi-backend Keras 3?
- Always get the best performance for your models. In our benchmarks, we found that JAX typically delivers the best training and inference performance on GPU, TPU, and CPU, but results vary from model to model, as non-XLA TensorFlow is occasionally faster on GPU. Because you can dynamically select the backend that delivers the best performance for your model without changing anything in your code, you can always train and serve with the fastest backend available to you.
- Unlock ecosystem optionality for your models. Any Keras 3 model can be instantiated as a PyTorch Module, can be exported as a TensorFlow SavedModel, or can be instantiated as a stateless JAX function. That means that you can use your Keras 3 models with PyTorch ecosystem packages, with the full range of TensorFlow deployment & production tools (like TF-Serving, TF.js and TFLite), and with JAX large-scale TPU training infrastructure. Write one model.py using Keras 3 APIs, and get access to everything the ML world has to offer.
- Leverage large-scale model parallelism & data parallelism with JAX. Keras 3 includes a brand new distribution API, the keras.distribution namespace, currently implemented for the JAX backend (coming soon to the TensorFlow and PyTorch backends). It makes it easy to do model parallelism, data parallelism, and combinations of both, at arbitrary model scales and cluster scales. Because it keeps the model definition, training logic, and sharding configuration all separate from each other, it makes your distribution workflow easy to develop and easy to maintain. See our starter guide.
- Maximize reach for your open-source model releases. Want to release a pretrained model? Want as many people as possible to be able to use it? If you implement it in pure TensorFlow or PyTorch, it will be usable by roughly half of the community. If you implement it in Keras 3, it is instantly usable by anyone regardless of their framework of choice (even if they're not Keras users themselves). Twice the impact at no added development cost.
- Use data pipelines from any source. The Keras 3 fit(), evaluate(), and predict() routines are compatible with tf.data.Dataset objects, with PyTorch DataLoader objects, and with NumPy arrays and Pandas dataframes, regardless of the backend you're using. You can train a Keras 3 + TensorFlow model on a PyTorch DataLoader or train a Keras 3 + PyTorch model on a tf.data.Dataset.
The full Keras API, available for JAX, TensorFlow, and PyTorch.
Keras 3 implements the full Keras API and makes it available with TensorFlow, JAX, and PyTorch — over a hundred layers, dozens of metrics, loss functions, optimizers, and callbacks, the Keras training and evaluation loops, and the Keras saving & serialization infrastructure. All the APIs you know and love are here.
Any Keras model that only uses built-in layers will immediately work with all supported backends. In fact, your existing tf.keras models that only use built-in layers can start running in JAX and PyTorch right away!
That's right, your codebase just gained a whole new set of capabilities.
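As a minimal sketch of switching backends without touching model code: Keras 3 reads the backend name from the KERAS_BACKEND environment variable when the package is first imported, with "jax", "tensorflow", and "torch" as the values for the three frameworks discussed here. The helper name select_backend is hypothetical, and the keras import is left commented out so the snippet runs without any framework installed.

```python
import os

# Backend names for the three frameworks discussed in this post.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name: str) -> str:
    """Hypothetical helper: record the backend choice for a later `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    # Keras reads this variable once, when the package is first imported,
    # so it has to be set before `import keras` runs anywhere in the process.
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# import keras  # uncomment once keras and the chosen backend are installed
```

Because the variable is only read at import time, changing it after keras has been imported has no effect in that process.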
Author multi-framework layers, models, metrics...
Keras 3 enables you to create components (like arbitrary custom layers or pretrained models) that will work the same in any framework. In particular, Keras 3 gives you access to the keras.ops namespace that works across all backends. It contains:
- A full implementation of the NumPy API.
Not something "NumPy-like" — just literally the
NumPy API, with the same functions and the same arguments.
- A set of neural network-specific functions that are absent from NumPy.
As long as you only use ops from
keras.ops, your custom layers,
custom losses, custom metrics, and custom optimizers
will work with JAX, PyTorch, and TensorFlow — with the same code.
That means that you can maintain only one component implementation (e.g. a single model.py together with a single checkpoint file), and you can use it in all frameworks, with the exact same numerics.
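To make the idea concrete, here is a toy sketch (plain Python, not Keras code) of why writing against a single ops namespace keeps a component portable: the component calls only the namespace, and the namespace forwards each call to whichever backend is active. Every name here (ToyOps, the two toy backends, softmax) is hypothetical, and the real keras.ops namespace is far broader.

```python
import math


class FloatBackend:
    """Toy backend: plain left-to-right Python arithmetic."""

    @staticmethod
    def exp(xs):
        return [math.exp(x) for x in xs]

    @staticmethod
    def sum(xs):
        total = 0.0
        for x in xs:
            total += x
        return total


class FsumBackend:
    """Toy backend: same operations, but a different summation routine."""

    @staticmethod
    def exp(xs):
        return [math.exp(x) for x in xs]

    @staticmethod
    def sum(xs):
        return math.fsum(xs)


class ToyOps:
    """Single namespace that forwards each op to the active backend."""

    def __init__(self, backend):
        self._backend = backend

    def exp(self, xs):
        return self._backend.exp(xs)

    def sum(self, xs):
        return self._backend.sum(xs)


def softmax(ops, logits):
    """Component written once, against the ops namespace only."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = ops.exp([x - shift for x in logits])
    denom = ops.sum(exps)
    return [e / denom for e in exps]


logits = [1.0, 2.0, 3.0]
out_a = softmax(ToyOps(FloatBackend), logits)
out_b = softmax(ToyOps(FsumBackend), logits)
```

The softmax function never names a backend, so swapping FloatBackend for FsumBackend changes nothing in the component itself; the two outputs agree to within floating-point rounding.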
...that works seamlessly with any JAX, TensorFlow, and PyTorch workflow.
Keras 3 is not just intended for Keras-centric workflows where you define a Keras model, a Keras optimizer, a Keras loss and metrics, and you call fit(), evaluate(), and predict().
It's also meant to work seamlessly with low-level backend-native workflows:
you can take a Keras model (or any other component, such as a loss or metric)
and start using it in a JAX training loop, a TensorFlow training loop,
or a PyTorch training loop, or as part of a JAX or PyTorch model,
with zero friction. Keras 3 provides exactly the same degree of low-level implementation flexibility in JAX and PyTorch as tf.keras previously did in TensorFlow. You can:
- Write a low-level JAX training loop to train a Keras model
- Write a low-level TensorFlow training loop to train a Keras model
- Write a low-level PyTorch training loop to train a Keras model, for example with a torch loss function.
- Use Keras layers in a PyTorch Module (because they are Module instances too).
- Use any PyTorch Module in a Keras model as if it were a Keras layer.
A new distribution API for large-scale data parallelism and model parallelism.
The models we've been working with have been getting larger and larger, so we wanted to provide a Kerasic solution to the multi-device model sharding problem. The API we designed keeps the model definition, the training logic, and the sharding configuration entirely separate from each other, meaning that your models can be written as if they were going to run on a single device. You can then add arbitrary sharding configurations to arbitrary models when it's time to train them.
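Data parallelism (replicating a small model identically on multiple devices) can be handled in just two lines: create a keras.distribution.DataParallel distribution over your devices, then set it as the active distribution.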
Model parallelism lets you specify sharding layouts for model variables and intermediate output tensors, along multiple named dimensions. In the typical case, you would organize available devices as a 2D grid (called a device mesh), where the first dimension is used for data parallelism and the second dimension is used for model parallelism. You would then configure your model to be sharded along the model dimension and replicated along the data dimension.
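The 2D layout described above can be illustrated without any framework. The sketch below is plain Python, not the keras.distribution API: it arranges eight hypothetical device ids into a 2 x 4 mesh whose rows play the role of the "data" axis and whose columns play the role of the "model" axis, then splits the columns of a toy weight matrix along the model axis while replicating it along the data axis. All function names are made up for illustration.

```python
def build_mesh(device_ids, data_size, model_size):
    """Arrange a flat device list into a data_size x model_size grid."""
    if len(device_ids) != data_size * model_size:
        raise ValueError("device count must equal data_size * model_size")
    return [
        device_ids[row * model_size:(row + 1) * model_size]
        for row in range(data_size)
    ]


def shard_columns(matrix, num_shards):
    """Split a matrix (list of rows) into num_shards equal column blocks."""
    num_cols = len(matrix[0])
    if num_cols % num_shards != 0:
        raise ValueError("column count must be divisible by num_shards")
    width = num_cols // num_shards
    return [
        [row[i * width:(i + 1) * width] for row in matrix]
        for i in range(num_shards)
    ]


def place_weight(mesh, weight):
    """Map each device to its shard: split on the model axis, replicated on data."""
    model_size = len(mesh[0])
    shards = shard_columns(weight, model_size)
    return {
        device: shards[col]
        for row in mesh
        for col, device in enumerate(row)
    }


mesh = build_mesh(list(range(8)), data_size=2, model_size=4)
weight = [[r * 8 + c for c in range(8)] for r in range(2)]  # 2 x 8 toy matrix
placement = place_weight(mesh, weight)
```

Devices in the same mesh column (for example 0 and 4) end up holding the same shard, which is what "replicated along the data dimension" means here.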
The API lets you configure the layout of every variable and every output tensor via regular expressions. This makes it easy to quickly specify the same layout for entire categories of variables.
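A rough sketch of the regex idea, in plain Python rather than the Keras API: keys are regular expressions over variable paths, and values are layouts, written here as one mesh-axis name (or None for "replicated") per tensor dimension. The variable paths, the rule table, and resolve_layout are all hypothetical, and resolving by first match with re.search is an assumption of this sketch.

```python
import re

# One rule per category of variables; None means "replicate along this dimension".
LAYOUT_RULES = [
    (r".*dense.*/kernel", (None, "model")),
    (r".*dense.*/bias", ("model",)),
    (r".*embedding.*/embeddings", ("model", None)),
]


def resolve_layout(variable_path, rules=LAYOUT_RULES):
    """Return the layout of the first rule whose pattern matches, else None."""
    for pattern, layout in rules:
        if re.search(pattern, variable_path):
            return layout
    return None  # unmatched variables stay fully replicated in this sketch


paths = ["dense_1/kernel", "dense_1/bias", "dense_2/kernel", "layer_norm/gamma"]
layouts = {path: resolve_layout(path) for path in paths}
```

A single pattern covers every dense kernel in the model, so adding more dense layers requires no new rules.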
The new distribution API is intended to be multi-backend, but is only available for the JAX backend for the time being. TensorFlow and PyTorch support is coming soon. Get started with this guide!
There's a wide range of pretrained models that you can start using today with Keras 3.
All 40 Keras Applications models (the keras.applications namespace) are available in all backends.
The vast array of pretrained models in KerasCV and KerasNLP also works with all backends.
Support for cross-framework data pipelines with all backends.
Multi-framework ML also means multi-framework data loading and preprocessing. Keras 3 models can be trained using a wide range of data pipelines — regardless of whether you're using the JAX, PyTorch, or TensorFlow backends. It just works.
- tf.data.Dataset pipelines: the reference for scalable production ML.
- PyTorch DataLoader objects.
- NumPy arrays and Pandas dataframes.
- Keras's own keras.utils.PyDataset objects.
Progressive disclosure of complexity.
Progressive disclosure of complexity is the design principle at the heart of the Keras API. Keras doesn't force you to follow a single "true" way of building and training models. Instead, it enables a wide range of different workflows, from the very high-level to the very low-level, corresponding to different user profiles.
That means that you can start out with simple workflows — such as using
Functional models and training them with
fit() — and when
you need more flexibility, you can easily customize different components while
reusing most of your prior code. As your needs become more specific,
you don't suddenly fall off a complexity cliff and you don't need to switch
to a different set of tools.
We've brought this principle to all of our backends. For instance,
you can customize what happens in your training loop while still
leveraging the power of
fit(), without having to write your own training loop
from scratch, just by overriding the train_step method.
Here's how it works in PyTorch and TensorFlow:
And here's the link to the JAX version.
A new stateless API for layers, models, metrics, and optimizers.
Do you enjoy functional programming? You're in for a treat.
All stateful objects in Keras (i.e. objects that own numerical variables that get updated during training or evaluation) now have a stateless API, making it possible to use them in JAX functions (which are required to be fully stateless):
- All layers and models have a stateless_call() method which mirrors __call__().
- All optimizers have a stateless_apply() method which mirrors apply().
- All metrics have a stateless_update_state() method which mirrors update_state() and a stateless_result() method which mirrors result().
These methods have no side-effects whatsoever: they take as input the current value of the state variables of the target object, and return the update values as part of their outputs, e.g.:
outputs, updated_non_trainable_variables = layer.stateless_call(trainable_variables, non_trainable_variables, inputs)
You never have to implement these methods yourself: they're automatically available as long as you've implemented the stateful version (e.g. call() or update_state()).
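The pattern itself can be sketched in plain Python. The RunningMean class below is a hypothetical stand-in for a metric, not a Keras class; it only shows how a stateless method can take the current state as input, return the new state, and leave the object untouched. For brevity the stateful methods here delegate to the stateless ones, which is the reverse of the direction described above, where the stateless versions are derived for you.

```python
class RunningMean:
    """Toy metric with a stateful API and a stateless mirror of it."""

    def __init__(self):
        # State variables owned by the object: [total, count].
        self.variables = [0.0, 0]

    def update_state(self, value):
        """Stateful version: replaces self.variables with the updated state."""
        self.variables = self.stateless_update_state(self.variables, value)

    def result(self):
        return self.stateless_result(self.variables)

    def stateless_update_state(self, variables, value):
        """Stateless version: reads state from its arguments, returns new state."""
        total, count = variables
        return [total + value, count + 1]

    def stateless_result(self, variables):
        total, count = variables
        return total / count if count else 0.0


metric = RunningMean()
state = metric.variables
for value in (1.0, 2.0, 3.0):
    # The caller threads the state through explicitly; the object is not mutated.
    state = metric.stateless_update_state(state, value)

mean = metric.stateless_result(state)
```

Because the state is passed in and returned rather than stored, the loop body is a pure function of its inputs, which is the property JAX transformations rely on.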
Moving from Keras 2 to Keras 3
Keras 3 is highly backwards compatible with Keras 2: it implements the full public API surface of Keras 2, with a limited number of exceptions, listed here. Most users will not have to make any code change to start running their Keras scripts on Keras 3.
Larger codebases are likely to require some code changes, since they are more likely to run into one of the exceptions listed above, and are more likely to have been using private APIs or deprecated APIs (such as the keras.src private namespace).
To help you move to Keras 3, we are releasing a complete migration guide
with quick fixes for all issues you might encounter.
You also have the option to ignore the changes in Keras 3 and just keep using Keras 2 with TensorFlow — this can be a good option for projects that are not actively developed but need to keep running with updated dependencies. You have two possibilities:
- If you were accessing keras as a standalone package, just switch to using the Python package tf_keras instead, which you can install via pip install tf_keras. The code and API are wholly unchanged: it's Keras 2.15 with a different package name. We will keep fixing bugs in tf_keras and we will keep regularly releasing new versions. However, no new features or performance improvements will be added, since the package is now in maintenance mode.
- If you were accessing tf.keras, there are no immediate changes until TensorFlow 2.16. TensorFlow 2.16+ will use Keras 3 by default. In TensorFlow 2.16+, to keep using Keras 2, you can first install tf_keras, and then export the environment variable TF_USE_LEGACY_KERAS=1. This will direct TensorFlow 2.16+ to resolve tf.keras to the locally-installed tf_keras package. Note that this may affect more than your own code, however: it will affect any package importing tf.keras in your Python process. To make sure your changes only affect your own code, you should use the tf_keras package.
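The environment-variable route described above can also be taken from inside Python when exporting a shell variable is inconvenient. This sketch assumes the variable is read when TensorFlow is first imported, so it has to be set earlier; the import itself is left commented out so the snippet runs without TensorFlow installed.

```python
import os

# Must run before the first `import tensorflow` anywhere in the process.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

# import tensorflow as tf  # with tf_keras installed, tf.keras would now resolve to Keras 2

# Simple check that the flag is in place for anything imported later.
legacy_enabled = os.environ.get("TF_USE_LEGACY_KERAS") == "1"
```

As with the shell export, this applies process-wide, so every package that imports tf.keras in the same process is affected.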
Enjoy the library!
We're excited for you to try out the new Keras and improve your workflows by leveraging multi-framework ML. Let us know how it goes: issues, points of friction, feature requests, or success stories — we're eager to hear from you!
Q: Is Keras 3 compatible with legacy Keras 2?
Code developed with
tf.keras can generally be run as-is with Keras 3
(with the TensorFlow backend). There's a limited number of incompatibilities you should be mindful
of, all addressed in this migration guide.
When it comes to using APIs from
tf.keras and Keras 3 side by side,
that is not possible — they're different packages, running on entirely separate engines.
Q: Do pretrained models developed in legacy Keras 2 work with Keras 3?
Generally, yes. Any
tf.keras model should work out of the box with Keras 3
with the TensorFlow backend (make sure to save it in the
.keras v3 format).
In addition, if the model only
uses built-in Keras layers, then it will also work out of the box
with Keras 3 with the JAX and PyTorch backends.
If the model contains custom layers written using TensorFlow APIs,
it is usually easy to convert the code to be backend-agnostic.
For instance, it only took us a few hours to convert all 40
tf.keras models from Keras Applications to be backend-agnostic.
Q: Can I save a Keras 3 model in one backend and reload it in another backend?
Yes, you can. There is no backend specialization in saved
.keras files whatsoever.
Your saved Keras models are framework-agnostic and can be reloaded with any backend.
However, note that reloading a model that contains custom components
with a different backend requires your custom components to be implemented
using backend-agnostic APIs, e.g. keras.ops.
Q: Can I use Keras 3 components inside tf.data pipelines?
With the TensorFlow backend, Keras 3 is fully compatible with tf.data (e.g. you can .map() a Sequential model into a tf.data pipeline).
With a different backend, Keras 3 has limited support for tf.data. You won't be able to .map() arbitrary layers or models into a tf.data pipeline. However, you will be able to use specific Keras 3 preprocessing layers with tf.data.
When it comes to using a tf.data pipeline (that does not use Keras) to feed your call to fit(), evaluate(), or predict(), that works out of the box with all backends.
Q: Do Keras 3 models behave the same when run with different backends?
Yes, numerics are identical across backends. However, keep in mind the following caveats:
- RNG behavior is different across different backends (even after seeding: your results will be deterministic in each backend but will differ across backends). So random weight initialization values and dropout values will differ across backends.
- Due to the nature of floating-point implementations, results are only identical up to 1e-7 precision in float32, per function execution. So when training a model for a long time, small numerical differences will accumulate and may end up resulting in noticeable numerical differences.
- Due to lack of support for average pooling with asymmetric padding in PyTorch, average pooling layers with padding="same" may result in different numerics on border rows/columns. This doesn't happen very often in practice: out of 40 Keras Applications vision models, only one was affected.
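The floating-point caveat comes down to the fact that floating-point addition is not associative, so two backends that evaluate the same expression in a different order can disagree in the last bits. Here is a small illustration in plain Python, which uses float64; the same effect shows up at a coarser scale in float32. The accumulate helper is a deliberately crude, hypothetical model in which every step contributes the same error, so treat it as an order-of-magnitude picture only.

```python
# The same three numbers, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# The two results differ in the last bits of the float64 representation.
difference = abs(left_to_right - right_to_left)


def accumulate(step_error, num_steps):
    """Crude drift estimate: every step contributes the same tiny error."""
    return step_error * num_steps


# A per-step discrepancy of 1e-7 repeated over 100,000 steps is no longer tiny.
drift = accumulate(1e-7, 100_000)
```

Neither summation order is wrong; they are simply different roundings of the same exact value, which is why long training runs on different backends can drift apart.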
Q: Does Keras 3 support distributed training?
Data-parallel distribution is supported out of the box in JAX, TensorFlow, and PyTorch. Model parallel distribution is supported out of the box for JAX with the keras.distribution API.
With TensorFlow, Keras 3 is compatible with tf.distribute: just open a Distribution Strategy scope and create / train your model within it.
Here's an example.
With PyTorch, Keras 3 is compatible with PyTorch's DistributedDataParallel utility.
Here's an example.
With JAX, you can do both data parallel and model parallel distribution using the keras.distribution API.
For instance, to do data parallel distribution, you only need the following code snippet:
distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)
For model parallel distribution, see the following guide.
You can also distribute training yourself via JAX APIs such as
jax.sharding. Here's an example.
Q: Can my custom Keras layers be used in native PyTorch Modules or with Flax Modules?
If they are only written using Keras APIs (e.g. the keras.ops namespace), then yes, your Keras layers will work out of the box with native PyTorch and JAX code.
In PyTorch, just use your Keras layer like any other PyTorch Module.
In JAX, make sure to use the stateless layer API, i.e. layer.stateless_call().
Q: Will you add more backends in the future? What about framework XYZ?
We're open to adding new backends as long as the target framework has a large user base or otherwise has some unique technical benefits to bring to the table. However, adding and maintaining a new backend is a large burden, so we're going to carefully consider each new backend candidate on a case-by-case basis, and we're not likely to add many new backends. We will not add any new frameworks that aren't yet well-established. We are now potentially considering adding a backend written in Mojo. If that's something you might find useful, please let the Mojo team know.