St_Hakky’s blog

Data Science / Human Resources / Web Applicationについて書きます

kerasでmultiple (複数の) 入力 / 出力 / 損失関数を扱う時のTipsをまとめる

こんにちは。

〇この記事のモチベーション

Deep Learningで自分でモデルとかを作ろうとすると、複数の入力や出力、そして損失関数を取扱たくなる時期が必ず来ると思います。

最近では、GoogleNetとかは中間層の途中で出力を出していたりするので、そういうのでも普通に遭遇します。

というわけで私も例に漏れず遭遇しました笑。

今回はkerasで複数の入力や出力、そして損失関数を取り扱うときにどうすればいいかについて実践したのでまとめておきます。

〇「複数の入力」を与えたい場合

これは簡単です。普段Modelのインスタンスを作る際に、inputsとoutputsを指定すると思いますが、その際に複数ある場合はリスト形式で渡せばいいだけです。

input_layer1 = Input(shape=(32,))
input_layer2 = Input(shape=(64,))

# ...(モデルの詳細を書く)...

model = Model(inputs=[input_layer1, input_layer2],
                         outputs=output_layer)

〇「複数の出力」を与えたい場合

このパターンを行う場合、損失関数をどう割り当てるかで場合がわかれますので、それについて書きます。

■「複数の出力」に対して、「そのそれぞれに同一の損失関数」を与えたい場合

これは簡単で、複数の入力の場合と同様に、以下のように設定するだけです。

model = Model(inputs=input_layer,
                        outputs=[output_layer1, output_layer2])
model.compile(optimizer='sgd',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])
■「複数の出力」に対して、「そのそれぞれに別の損失関数」を与えたい場合

さて、このパターンが結構複雑ってこともないんですが、色々やろうとした時につまづいたポイントです(特にオリジナルな損失関数とかを使おうとした時に自分ははまりました汗)

このパターンであるのは、以下のようなシチュエーションです。

  • 複数の出力で且つそのそれぞれに対して、別の損失関数を割り当てたい場合
  • それぞれの損失関数同士の比重、つまりどれくらい重要視するかを変えたい場合

論文にあるようなものとしては、マルチタスク学習とかは上のパターンに遭遇しやすいかなと思います。この実装をどうすればいいかですが、こちらのソースコード を見るか(514行目くらいから)、ドキュメントを見ると、わかると思います。

model = Model(inputs=input_layer,
                         outputs=[output_layer1, output_layer2])
model.compile(optimizer='sgd',
                        loss={'categorical_crossentropy': 'output_layer1', 'mse': 'output_layer2'},
                        loss_weights={'output_layer1': 0.1, 'output_layer2':0.7},
                        metrics=['accuracy'])

kerasではlossに値で渡した場合、辞書で渡した場合、そしてリストで渡した場合、のそれぞれについて、別のこととしてサポートしてくれているみたいですね。

このことをまとめると以下の様になると思います。

  • 値:一つのloss関数のみが適用される
  • リスト:複数のloss関数を、複数の出力のそれぞれに対して適用する
  • 辞書:ユーザーが辞書式で指定した、出力層の名前と対応するloss関数を出力層に対して適用する

ちなみに、'output_layer1'とかはlayerの名前です。Denseとかを使って、層を定義するときに name引数に対して値を指定すると層の名前を指定できると思いますが、その名前です。これを指定してなかったり、指定したけど違う名前にしていたとかだと、当たり前ですが、実行できないので注意が必要です。

こんなもんですかね。それでは。