Многоканальный выход Keras: функция пользовательских потерь
Я использую модель множественного вывода в керасе
model1 = Model(input=x, output=[y2,y3])
model1.compile((optimizer='sgd', loss=cutom_loss_function)
моя custom_loss_function
;
def custom_loss(y_true, y_pred):
y2_pred = y_pred[0]
y2_true = y_true[0]
loss = K.mean(K.square(y2_true - y2_pred), axis=-1)
return loss
Я хочу только обучать сеть на выходе y2
.
Какова форма/структура аргументов y_pred
и y_true
в функции потерь при использовании нескольких выходов? Могу ли я получить к ним доступ, как указано выше? Это y_pred[0]
или y_pred[:,0]
?
Ответы
Ответ 1
Я хочу только обучать сеть на выходе y2.
На основе функционального API-интерфейса Keras вы можете достичь этого с помощью
model1 = Model(input=x, output=[y2,y3])
model1.compile(optimizer='sgd', loss=custom_loss_function,
loss_weights=[1., 0.0])
Какова форма/структура аргументов y_pred и y_true в функции потерь при использовании нескольких выходов? Могу ли я получить к ним доступ, как указано выше? Это y_pred [0] или y_pred [:, 0]
В keras multi-output модели функция потерь применяется для каждого выхода отдельно. В псевдокоде:
loss = sum( [ loss_function( output_true, output_pred ) for ( output_true, output_pred ) in zip( outputs_data, outputs_model ) ] )
Функциональность функции потери на нескольких выходах кажется мне недоступной. Вероятно, это может быть достигнуто за счет включения функции потерь в качестве уровня сети.
Ответ 2
Ответ Шараполаса правильный.
Однако есть лучший способ, чем использовать слой для построения пользовательских функций потерь со сложной взаимозависимостью нескольких выходных данных модели.
Метод, который я знаю, используется на практике - никогда не вызывать model.compile
а только model._make_predict_function()
. model.output
вы можете продолжить и создать собственный метод оптимизатора, вызвав там model.output
. Это даст вам все выходные данные, [y2, y3] в вашем случае. Когда вы делаете с ним свою магию, возьмите keras.optimizer
и используйте его метод get_update, используя ваш model.trainable_weights и ваш проигрыш. Наконец, верните функцию keras.function
со списком необходимых входных данных (в вашем случае только model.input
) и обновления, которые вы только что получили из вызова optimizer.get_update. Эта функция теперь заменяет model.fit.
Вышеуказанное часто используется в алгоритмах PolicyGradient, таких как A3C или PPO. Вот пример того, что я пытался объяснить: https://github.com/Hyeokreal/Actor-Critic-Continuous-Keras/blob/master/a2c_continuous.py Посмотрите на методы build_model и crit_optimizer и прочитайте документацию kreas.backend.function. чтобы понять, что происходит.
Я обнаружил, что у этого способа часто бывают проблемы с управлением сессиями, и в настоящее время он не работает в tf-2.0 keras вообще. Следовательно, если кто-нибудь знает метод, пожалуйста, дайте мне знать. Я пришел сюда в поисках одного :)