Scikit-Learn APIのためのラッパー

keras.wrappers.scikit_learn.pyにあるラッパーを通して,KerasのSequentialモデル(1つの入力のみ)をScikit-Learnワークフローの一部として利用できます.

2つのラッパーが利用可能です:

keras.wrappers.sk_learn.KerasClassifier(build_fn=None, **sk_params), これはScikit-Learnのclassifierインターフェースを実装します.

keras.wrappers.sk_learn.KerasRegressor(build_fn=None, **sk_params), これはScikit-Learnのregressorインターフェースを実装します.

引数

  • build_fn: 呼び出し可能な関数,または,クラスインスタンス
  • sk_params: モデルパラメータとfittingパラメータ

build_fnは,Kerasモデルを構成し,コンパイルし,返します. このモデルは,fit/predictのために利用されます.以下の3つの値のうち 1つをbuild_fnに渡すことができます:

  1. 関数
  2. call メソッドを実装したクラスのインスタンス
  3. None.これはKerasClassifierまたはKerasRegressorを継承したクラスを意味します.この call メソッドはbuild_fnのデフォルトとして扱われます.

sk_paramsはモデルパラメータとfittingパラメータの両方を取ります. モデルパラメータはbuild_fnの引数です.sk_paramsに何も与えなくとも予測器が作れるように, scikit-learnの他の予測器と同様に,build_fnはその引数にデフォルトパラメータを取ります.

また,sk_paramsfitpredictpredict_proba,および,scoreメソッドを 呼ぶためのパラメータも取ります(例えば,epochs, batch_size). fitting (predicting) パラメータは以下の順番で選択されます:

  1. fitpredictpredict_proba,および,scoreメソッドの辞書引数に与えられた値
  2. sk_paramsに与えられた値
  3. keras.models.Sequentialfitpredictpredict_proba,および,scoreメソッドのデフォルト値

scikit-learnのgrid_searchAPIを利用するとき,チューニングパラメータはsk_paramsに渡したものになります. これには,fittingパラメータも含まれます.つまり,最適なモデルパラメータだけでなく,最適なbatch_sizeepochsの探索に,grid_searchを利用できます.