
The base Oracle class

Oracle class

keras_tuner.Oracle(
    objective,
    max_trials=None,
    hyperparameters=None,
    allow_new_entries=True,
    tune_new_entries=True,
    seed=None,
)

Implements a hyperparameter optimization algorithm.

Arguments

  • objective: A string or keras_tuner.Objective instance. If a string, the direction of the optimization (min or max) will be inferred.
  • max_trials: Integer, the maximum number of trials (model configurations) to test. Note that the oracle may stop the search before max_trials models have been tested if the search space has been exhausted.
  • hyperparameters: Optional HyperParameters instance. Can be used to override (or register in advance) hyperparameters in the search space.
  • allow_new_entries: Boolean, whether the hypermodel is allowed to request hyperparameter entries not listed in hyperparameters. Defaults to True.
  • tune_new_entries: Boolean, whether hyperparameter entries that the hypermodel requests but that are not specified in hyperparameters should be added to the search space. If not, the default values of these entries will be used. Defaults to True.
  • seed: Integer, the random seed. Defaults to None.
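The interaction between allow_new_entries and tune_new_entries can be read as a small decision procedure. The sketch below illustrates it under simplifying assumptions: resolve_entry is a hypothetical helper (not part of KerasTuner), and the search space is modeled as a plain dict of name to current value rather than a HyperParameters instance.

```python
# Hypothetical sketch, not KerasTuner's implementation: how an oracle might
# handle a hyperparameter that the hypermodel requests.

def resolve_entry(name, default, space,
                  allow_new_entries=True, tune_new_entries=True):
    """Return a (source, value) pair for the requested hyperparameter.

    `space` is a plain dict standing in for the registered search space.
    """
    if name in space:
        # Registered in advance (e.g. via `hyperparameters`): use that entry.
        return "search space", space[name]
    if not allow_new_entries:
        # The hypermodel may only use entries listed in `hyperparameters`.
        raise KeyError(f"hyperparameter {name!r} is not in the search space")
    if tune_new_entries:
        # Register the new entry so the oracle can tune it in later trials.
        space[name] = default
        return "added to search space", default
    # Accept the entry, but always use its default value.
    return "default", default
```

With both flags left at True, any new entry silently joins the search space; setting allow_new_entries=False turns an unexpected request into an error.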

create_trial method

Oracle.create_trial(tuner_id)

Create a new Trial to be run by the Tuner.

A Trial corresponds to a unique set of hyperparameters to be run by Tuner.run_trial.

Arguments

  • tuner_id: A string, the ID that identifies the Tuner requesting a Trial. Tuners that should run the same trial (for instance, when running a multi-worker model) should have the same ID.

Returns

A Trial object containing a set of hyperparameter values to run in a Tuner.
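The tuner_id contract (tuners sharing an ID run the same trial) can be illustrated with a toy oracle. This is a simplified sketch, assuming trials are plain dicts and that populate_space always suggests the same values; ToyOracle and its attributes are hypothetical names, not KerasTuner's classes.

```python
import itertools


class ToyOracle:
    """Hypothetical, simplified oracle illustrating create_trial's contract."""

    def __init__(self):
        self._ids = itertools.count()
        self.ongoing = {}  # tuner_id -> trial currently running for that ID

    def populate_space(self, trial_id):
        # Placeholder search strategy: always suggest the same values.
        return {"values": {"units": 32}, "status": "RUNNING"}

    def create_trial(self, tuner_id):
        # Tuners sharing an ID (e.g. workers of one multi-worker model)
        # receive the trial that is already running for that ID.
        if tuner_id in self.ongoing:
            return self.ongoing[tuner_id]
        trial_id = str(next(self._ids))
        response = self.populate_space(trial_id)
        trial = {"trial_id": trial_id,
                 "values": response["values"],
                 "status": response["status"]}
        if trial["status"] == "RUNNING":
            self.ongoing[tuner_id] = trial
        return trial
```

Calling create_trial("tuner0") twice returns the same trial object, while a different tuner ID gets a fresh trial.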


end_trial method

Oracle.end_trial(trial_id, status="COMPLETED")

Record the measured objective for a set of parameter values.

Arguments

  • trial_id: A string, the unique ID for this trial.
  • status: A string, either "COMPLETED" or "INVALID". A status of "INVALID" means that the trial crashed or was deemed infeasible.
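The bookkeeping end_trial performs can be sketched as follows. This toy version assumes trials are dicts holding a list of reported objective values and uses a placeholder score_trial (last reported value); the names are hypothetical and the logic is a simplification, not KerasTuner's code.

```python
class ToyOracle:
    """Hypothetical sketch of end_trial bookkeeping (not KerasTuner's code)."""

    def __init__(self):
        self.trials = {}   # trial_id -> {"metrics": [...], "status": ...}
        self.ongoing = {}  # tuner_id -> trial_id currently running

    def score_trial(self, trial):
        # Placeholder scoring: use the last reported objective value.
        trial["score"] = trial["metrics"][-1]

    def end_trial(self, trial_id, status="COMPLETED"):
        if status not in ("COMPLETED", "INVALID"):
            raise ValueError(f"unexpected status: {status!r}")
        trial = self.trials[trial_id]
        trial["status"] = status
        # The trial is no longer running on any tuner.
        self.ongoing = {t: i for t, i in self.ongoing.items() if i != trial_id}
        if status == "COMPLETED":
            # Only completed trials are scored; INVALID ones get no score.
            self.score_trial(trial)
        return trial
```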

get_best_trials method

Oracle.get_best_trials(num_trials=1)

Returns a list of the best Trials (at most num_trials of them), ranked by the objective.
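The ranking idea can be sketched in a few lines. This is a standalone illustration, assuming trials are dicts with "status" and "score" keys and that the objective direction is passed explicitly; it is not KerasTuner's implementation.

```python
def get_best_trials(trials, num_trials=1, direction="min"):
    """Hypothetical sketch: rank completed trials by score.

    `trials` is a list of dicts with "status" and (if completed) "score" keys.
    """
    # Only completed trials have a meaningful score.
    completed = [t for t in trials if t["status"] == "COMPLETED"]
    # Lower is better for "min" objectives, higher for "max" objectives.
    ranked = sorted(completed, key=lambda t: t["score"],
                    reverse=(direction == "max"))
    return ranked[:num_trials]
```

Asking for more trials than have completed simply returns all completed trials.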


get_state method

Oracle.get_state()

Returns the current state of this object.

This method is called during save.

Returns

A dictionary of serializable objects as the state.


set_state method

Oracle.set_state(state)

Sets the current state of this object.

This method is called during reload.

Arguments

  • state: A dictionary of serialized objects as the state to restore.
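Together, get_state and set_state form a save/reload round trip. The toy oracle below (a hypothetical class, not KerasTuner's) keeps only JSON-serializable state and uses json to stand in for whatever storage the real save and reload steps use.

```python
import json


class ToyOracle:
    """Hypothetical oracle showing a get_state/set_state round trip."""

    def __init__(self, seed=None):
        self.seed = seed
        self.trials = {}  # trial_id -> {"values": ..., "status": ...}

    def get_state(self):
        # Only serializable objects go into the state.
        return {"seed": self.seed, "trials": self.trials}

    def set_state(self, state):
        # Restore exactly what get_state produced.
        self.seed = state["seed"]
        self.trials = state["trials"]


oracle = ToyOracle(seed=42)
oracle.trials["0"] = {"values": {"units": 32}, "status": "COMPLETED"}
saved = json.dumps(oracle.get_state())   # stands in for "save"

restored = ToyOracle()
restored.set_state(json.loads(saved))    # stands in for "reload"
```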

score_trial method

Oracle.score_trial(trial)

Score a completed Trial.

This method can be overridden in subclasses to provide a score for a set of hyperparameter values. This method is called from end_trial on completed Trials.

Arguments

  • trial: A completed Trial object.
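One way a subclass might override score_trial is to score a trial by the best value its objective reached over all reported steps. The Trial and BestStepScorer classes below are hypothetical stand-ins, assuming each metric's history is stored as a {step: value} dict; the real Trial object has a different structure.

```python
class Trial:
    """Hypothetical minimal trial: metric history as {step: value} per metric."""

    def __init__(self, history):
        self.history = history  # e.g. {"val_loss": {0: 0.9, 1: 0.6}}
        self.score = None
        self.best_step = None


class BestStepScorer:
    """Sketch of a score_trial override: score by the best step seen."""

    def __init__(self, objective_name, direction="min"):
        self.objective_name = objective_name
        self.direction = direction

    def score_trial(self, trial):
        steps = trial.history[self.objective_name]
        pick = min if self.direction == "min" else max
        # Step whose objective value is best in the chosen direction.
        best_step = pick(steps, key=steps.get)
        trial.best_step = best_step
        trial.score = steps[best_step]
```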

populate_space method

Oracle.populate_space(trial_id)

Fill the hyperparameter space with values for a trial.

This method should be overridden in subclasses. It is called by create_trial to populate the hyperparameter space with values for the new trial.

Arguments

  • trial_id: A string, the ID for this Trial.

Returns

A dictionary with keys "values" and "status", where "values" is a mapping of parameter names to suggested values, and "status" is the TrialStatus that should be returned for this trial (one of "RUNNING", "IDLE", or "STOPPED").
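A random-search style populate_space returning this dictionary might look like the sketch below. It assumes a simplified search space (a dict of name to list of candidate values) and a hypothetical max_collisions limit: when no unseen combination turns up after that many draws, the space is treated as exhausted and "STOPPED" is returned. ToyRandomOracle is illustrative only.

```python
import random


class ToyRandomOracle:
    """Hypothetical random-search sketch of populate_space."""

    def __init__(self, space, max_collisions=20, seed=None):
        self.space = space          # name -> list of candidate values
        self.max_collisions = max_collisions
        self.tried = set()          # value combinations already suggested
        self.rng = random.Random(seed)

    def populate_space(self, trial_id):
        for _ in range(self.max_collisions):
            values = {name: self.rng.choice(choices)
                      for name, choices in self.space.items()}
            key = tuple(sorted(values.items()))
            if key not in self.tried:
                self.tried.add(key)
                return {"values": values, "status": "RUNNING"}
        # No unseen combination found: treat the space as exhausted.
        return {"values": None, "status": "STOPPED"}
```

Tracking already-suggested combinations is what lets the oracle end the search before max_trials when the space is exhausted.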


update_trial method

Oracle.update_trial(trial_id, metrics, step=0)

Used by a worker to report the status of a trial.

Arguments

  • trial_id: A string, a previously seen trial id.
  • metrics: Dict mapping metric names to float values, the current values of this trial's metrics.
  • step: Optional float, used when reporting intermediate results. It is the current position in a timeseries that represents the state of the trial (for example, the current epoch), and the reported metrics are associated with this step.

Returns

Trial object. Trial.status will be set to "STOPPED" if the Trial should be stopped early.
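The early-stopping signal in the return value can be illustrated with a toy update_trial. The stopping rule here is an assumption made for illustration, not KerasTuner's behavior: a trial is marked "STOPPED" once the objective (lower is better) has not improved for `patience` consecutive steps. ToyOracle and patience are hypothetical names.

```python
class ToyOracle:
    """Hypothetical sketch of update_trial with an assumed early-stop rule."""

    def __init__(self, objective="val_loss", patience=2):
        self.objective = objective
        self.patience = patience  # assumed rule: steps without improvement
        self.trials = {}  # trial_id -> {"history": {step: value}, "status": ...}

    def update_trial(self, trial_id, metrics, step=0):
        # trial_id must refer to a previously seen trial.
        trial = self.trials[trial_id]
        trial["history"][step] = metrics[self.objective]
        values = [trial["history"][s] for s in sorted(trial["history"])]
        best_index = values.index(min(values))
        # Number of steps reported since the best value was observed.
        if len(values) - 1 - best_index >= self.patience:
            trial["status"] = "STOPPED"
        return trial
```

A worker would call update_trial after each step and stop training as soon as the returned status is "STOPPED".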