Keras 3 API documentation / Utilities / Scikit-Learn API wrappers

Scikit-Learn API wrappers


SKLearnClassifier class

    keras.wrappers.SKLearnClassifier(model, warm_start=False, model_kwargs=None, fit_kwargs=None)

A scikit-learn-compatible classifier wrapper for Keras models.

Note that model initialization and training involve sources of randomness. Refer to Reproducibility in Keras Models for how to control randomness.

Arguments
  • model: Model. An instance of Model, or a callable returning such an object. A Model instance must already be compiled, and it is cloned with keras.models.clone_model before being fitted unless warm_start=True. A callable must accept at least X and y as keyword arguments, as well as any other arguments the user passes via model_kwargs.
  • warm_start: bool, defaults to False. Whether to reuse the model weights from the previous fit. If True, the given model won't be cloned and the weights from the previous fit will be reused.
  • model_kwargs: dict, defaults to None. Keyword arguments passed to model, if model is callable.
  • fit_kwargs: dict, defaults to None. Keyword arguments passed to model.fit. These can also be passed directly to the fit method of the scikit-learn wrapper; values passed directly to fit take precedence over these.
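The precedence rule for fit_kwargs can be sketched in plain Python. The helper merge_fit_kwargs below is hypothetical: it only illustrates the documented behaviour (stored fit_kwargs are applied first, then arguments passed directly to fit override them) and is not taken from the Keras source.

```python
def merge_fit_kwargs(fit_kwargs=None, **direct_kwargs):
    """Hypothetical helper illustrating the documented precedence rule.

    Keyword arguments given directly to fit() override those stored in
    fit_kwargs. This is a sketch, not the Keras implementation.
    """
    merged = dict(fit_kwargs or {})  # copy so the stored dict is not mutated
    merged.update(direct_kwargs)     # direct arguments win on conflicts
    return merged


# Constructor-level defaults, as they might be passed via fit_kwargs=...
stored = {"epochs": 10, "batch_size": 32}

# Equivalent of est.fit(X, y, epochs=5): epochs is overridden, batch_size kept.
print(merge_fit_kwargs(stored, epochs=5))  # {'epochs': 5, 'batch_size': 32}
```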

Attributes
  • model_: Model. The fitted model.
  • history_: dict. The history of the fit, as returned by model.fit.
  • classes_: array-like, shape=(n_classes,). The class labels.
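To make the classes_ attribute concrete: it holds the distinct labels seen during fit, and a softmax model trained with categorical_crossentropy needs each label turned into a one-hot row with one column per class. The pure-Python sketch below assumes sorted labels and one-hot targets (the usual scikit-learn convention); the helper names fit_classes and one_hot are hypothetical, and the wrapper's actual encoding code may differ.

```python
def fit_classes(y):
    """Return the sorted distinct labels, analogous to the classes_ attribute."""
    return sorted(set(y))


def one_hot(y, classes):
    """Encode each label as a one-hot row with one column per class."""
    index = {label: i for i, label in enumerate(classes)}
    rows = []
    for label in y:
        row = [0.0] * len(classes)
        row[index[label]] = 1.0
        rows.append(row)
    return rows


y = ["cat", "dog", "bird", "dog"]
classes_ = fit_classes(y)        # ['bird', 'cat', 'dog']
encoded = one_hot(y, classes_)   # "cat" -> [0.0, 1.0, 0.0]
```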

Example
Here we use a function that creates a basic MLP model, choosing the input and output shapes dynamically from the data. We will use it to create our scikit-learn model.

from keras import Input, Model
from keras.layers import Dense

def dynamic_model(X, y, loss, layers=(10,)):
    # Creates a basic MLP model dynamically choosing the input and
    # output shapes.
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    # One softmax unit per class (categorical_crossentropy expects one-hot y).
    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = Dense(n_outputs, activation="softmax")(hidden)
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model
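The only data-dependent part of dynamic_model is how it sizes the input and output layers. The sketch below isolates that logic on plain shape tuples so you can check which sizes a dataset produces without building a model. The helper name infer_io_sizes is hypothetical, and the classifier case assumes the targets reach the callable one-hot encoded, as the categorical_crossentropy loss implies.

```python
def infer_io_sizes(x_shape, y_shape):
    """Mirror dynamic_model's shape logic on plain shape tuples.

    n_features_in is the number of input columns; the output layer gets one
    unit per target column, or a single unit when y is one-dimensional.
    """
    n_features_in = x_shape[1]
    n_outputs = y_shape[1] if len(y_shape) > 1 else 1
    return n_features_in, n_outputs


# One-hot classification targets with 3 classes -> 3 output units.
print(infer_io_sizes((1000, 10), (1000, 3)))  # (10, 3)
# 1-D targets -> a single output unit.
print(infer_io_sizes((1000, 10), (1000,)))    # (10, 1)
```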

You can then use this function to create a scikit-learn compatible model and fit it on some data.

from sklearn.datasets import make_classification
from keras.wrappers import SKLearnClassifier

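# n_informative=5 satisfies make_classification's requirement that
# n_classes * n_clusters_per_class <= 2**n_informative.
X, y = make_classification(
    n_samples=1000, n_features=10, n_informative=5, n_classes=3
)
est = SKLearnClassifier(
    model=dynamic_model,
    model_kwargs={
        "loss": "categorical_crossentropy",
        "layers": [20, 20, 20],
    },
)

est.fit(X, y, epochs=5)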
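Once fitted, the classifier's per-class probabilities still need to be turned back into labels. A common approach, sketched below under the assumption that output column i corresponds to classes_[i], is to take the argmax of each row; scikit-learn's ClassifierMixin.score then reports mean accuracy. The helpers decode and accuracy are hypothetical illustrations, not the wrapper's code.

```python
def decode(probabilities, classes):
    """Map each row of class probabilities to the label with the highest score."""
    labels = []
    for row in probabilities:
        best = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        labels.append(classes[best])
    return labels


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


classes_ = [0, 1, 2]
probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.5, 0.3]]
y_pred = decode(probs, classes_)    # [0, 2, 1]
print(accuracy([0, 2, 2], y_pred))  # 0.6666666666666666
```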


SKLearnRegressor class

    keras.wrappers.SKLearnRegressor(model, warm_start=False, model_kwargs=None, fit_kwargs=None)

A scikit-learn-compatible regressor wrapper for Keras models.

Note that model initialization and training involve sources of randomness. Refer to Reproducibility in Keras Models for how to control randomness.

Arguments
  • model: Model. An instance of Model, or a callable returning such an object. A Model instance must already be compiled, and it is cloned with keras.models.clone_model before being fitted unless warm_start=True. A callable must accept at least X and y as keyword arguments, as well as any other arguments the user passes via model_kwargs.
  • warm_start: bool, defaults to False. Whether to reuse the model weights from the previous fit. If True, the given model won't be cloned and the weights from the previous fit will be reused.
  • model_kwargs: dict, defaults to None. Keyword arguments passed to model, if model is callable.
  • fit_kwargs: dict, defaults to None. Keyword arguments passed to model.fit. These can also be passed directly to the fit method of the scikit-learn wrapper; values passed directly to fit take precedence over these.
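The warm_start rule described above can be modelled with a toy wrapper. In this hypothetical sketch, the "model" is a dict whose steps counter stands in for trained weights, and copy.deepcopy followed by a reset stands in for keras.models.clone_model (which builds a new model with freshly initialized weights). It only illustrates the documented behaviour and is not the Keras implementation.

```python
import copy


class ToyWrapper:
    """Hypothetical stand-in that mimics the documented warm_start rule.

    The "model" is a dict whose "steps" counter plays the role of trained
    weights. This is an illustration, not the Keras implementation.
    """

    def __init__(self, model, warm_start=False):
        self.model = model
        self.warm_start = warm_start

    def fit(self):
        if not self.warm_start:
            # Default: train a fresh copy; like clone_model, the copy starts
            # from newly initialised weights and the given model is untouched.
            model = copy.deepcopy(self.model)
            model["steps"] = 0
        elif hasattr(self, "model_"):
            # warm_start=True after a previous fit: keep training model_.
            model = self.model_
        else:
            # First fit with warm_start=True: use the given model, uncloned.
            model = self.model
        model["steps"] += 1  # stands in for one call to model.fit
        self.model_ = model
        return self


cold = ToyWrapper({"steps": 0})
cold.fit().fit()
print(cold.model_["steps"])  # 1: each fit starts over

warm = ToyWrapper({"steps": 0}, warm_start=True)
warm.fit().fit()
print(warm.model_["steps"])  # 2: the second fit continued the first
```

With warm_start=False every fit starts from a fresh clone, while with warm_start=True repeated calls keep training the same object that was passed in.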

Attributes
  • model_: Model. The fitted model.

Example
Here we use a function that creates a basic MLP model, choosing the input and output shapes dynamically from the data. We will use it to create our scikit-learn model.

from keras import Input, Model
from keras.layers import Dense

def dynamic_model(X, y, loss, layers=(10,)):
    # Creates a basic MLP model dynamically choosing the input and
    # output shapes.
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

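    # Linear output units for regression (Dense applies no activation by
    # default); a single softmax unit would always output 1.0.
    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = Dense(n_outputs)(hidden)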
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model
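A single-unit softmax layer is not a usable regression head: softmax normalises each output against the others, so with only one unit it returns 1.0 whatever the network computes. For regression, the final Dense layer should use a linear activation (activation=None, the Dense default). The short worked check below uses a plain-Python softmax to show the effect.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# With one output unit, the value is normalised against itself,
# so the "prediction" is 1.0 no matter what the logit is.
for z in (-3.0, 0.0, 42.0):
    print(softmax([z]))  # [1.0] every time
```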

You can then use this function to create a scikit-learn compatible model and fit it on some data.

from sklearn.datasets import make_regression
from keras.wrappers import SKLearnRegressor

X, y = make_regression(n_samples=1000, n_features=10)
est = SKLearnRegressor(
    model=dynamic_model,
    model_kwargs={
        "loss": "mse",
        "layers": [20, 20, 20],
    },
)

est.fit(X, y, epochs=5)
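scikit-learn regressors conventionally report the coefficient of determination, R² = 1 - SS_res / SS_tot, from their score method (RegressorMixin.score). Assuming SKLearnRegressor follows that convention, est.score(X, y) computes the quantity sketched below; r2_score here is a hand-written illustration, not scikit-learn's own function.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
print(r2_score(y_true, y_pred))  # about 0.9486
```

A perfect fit scores 1.0, and always predicting the mean of y_true scores 0.0.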


SKLearnTransformer class

    keras.wrappers.SKLearnTransformer(model, warm_start=False, model_kwargs=None, fit_kwargs=None)

A scikit-learn-compatible transformer wrapper for Keras models.

Note that this is a transformer in the scikit-learn sense (an estimator with a transform method), not a transformer in the deep learning sense.

Also note that model initialization and training involve sources of randomness. Refer to Reproducibility in Keras Models for how to control randomness.

Arguments
  • model: Model. An instance of Model, or a callable returning such an object. A Model instance must already be compiled, and it is cloned with keras.models.clone_model before being fitted unless warm_start=True. A callable must accept at least X and y as keyword arguments, as well as any other arguments the user passes via model_kwargs.
  • warm_start: bool, defaults to False. Whether to reuse the model weights from the previous fit. If True, the given model won't be cloned and the weights from the previous fit will be reused.
  • model_kwargs: dict, defaults to None. Keyword arguments passed to model, if model is callable.
  • fit_kwargs: dict, defaults to None. Keyword arguments passed to model.fit. These can also be passed directly to the fit method of the scikit-learn wrapper; values passed directly to fit take precedence over these.

Attributes
  • model_: Model. The fitted model.
  • history_: dict. The history of the fit, as returned by model.fit.

Example
A common use case for a scikit-learn transformer is a pipeline step that produces embeddings of your data. Here we assume my_package.my_model is a Keras model that takes the input and returns embeddings of the data, and my_package.my_data is your dataset loader.

from my_package import my_model, my_data
from keras.wrappers import SKLearnTransformer
from sklearn.frozen import FrozenEstimator # requires scikit-learn>=1.6
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = my_data()

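# FrozenEstimator expects an already-fitted estimator, so fit the
# transformer (and thus its Keras model) before freezing it.
trs = SKLearnTransformer(model=my_model).fit(X, y)
pipe = make_pipeline(FrozenEstimator(trs), HistGradientBoostingClassifier())
pipe.fit(X, y)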

Note that in the above example, FrozenEstimator prevents any further training of the transformer step when the pipeline is fitted, which is useful when you don't want to change the embedding model at hand.
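FrozenEstimator's behaviour can be illustrated with a simplified stand-in. The hypothetical Frozen class below mirrors the two properties that matter here: its fit method does not retrain the wrapped estimator (it only checks that the estimator was already fitted, using the scikit-learn convention that fitted attributes end in an underscore), and transform delegates to the wrapped estimator. Scale is a toy transformer, not part of any library; the real sklearn.frozen.FrozenEstimator does more than this sketch.

```python
class Frozen:
    """Hypothetical, simplified stand-in for sklearn.frozen.FrozenEstimator."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None):
        # Refuse to wrap an unfitted estimator, then do nothing else:
        # the wrapped estimator's learned state is never updated here.
        fitted = any(
            name.endswith("_") and not name.startswith("__")
            for name in vars(self.estimator)
        )
        if not fitted:
            raise RuntimeError("estimator must be fitted before freezing")
        return self

    def transform(self, X):
        return self.estimator.transform(X)


class Scale:
    """Toy transformer that learns a scale factor during fit."""

    def fit(self, X, y=None):
        self.factor_ = max(X)
        return self

    def transform(self, X):
        return [x / self.factor_ for x in X]


scaler = Scale().fit([2.0, 4.0])  # factor_ learned once: 4.0
frozen = Frozen(scaler)
frozen.fit([10.0, 20.0])          # no effect: factor_ stays 4.0
print(frozen.transform([8.0]))    # [2.0]
```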