Потеря триэтки объекта с керасом
Я пытаюсь реализовать facenet в Keras с back-end Thensorflow, и у меня есть некоторая проблема с потерей триплета. ![введите описание изображения здесь]()
Я вызываю функцию fit с 3 * n количеством изображений, а затем я определяю свою пользовательскую функцию потерь следующим образом:
def triplet_loss(self, y_true, y_pred):
embeddings = K.reshape(y_pred, (-1, 3, output_dim))
positive_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,1]),axis=-1)
negative_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,2]),axis=-1)
return K.mean(K.maximum(0.0, positive_distance - negative_distance + _alpha))
self._model.compile(loss=triplet_loss, optimizer="sgd")
self._model.fit(x=x,y=y,nb_epoch=1, batch_size=len(x))
где y - только фиктивный массив, заполненный 0s
Проблема заключается в том, что даже после первой итерации с размером партии 20 модель начинает прогнозировать одно и то же вложение для всех изображений. Поэтому, когда я сначала делаю предсказание по партии, каждое вложение отличается. Затем я делаю подгонку и предсказываю снова, и вдруг все вложения становятся почти одинаковыми для всех изображений в партии
Также обратите внимание, что в конце модели есть слой Лямбда. Он нормализует выход сети, поэтому все вложения имеют единичную длину, как это было предложено в исследовании на лицевой панели.
Может кто-нибудь помочь мне здесь?
Резюме модели
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D) (None, 112, 112, 64) 9472 input_1[0][0]
____________________________________________________________________________________________________
batchnormalization_1 (BatchNormal(None, 112, 112, 64) 128 convolution2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 56, 56, 64) 0 batchnormalization_1[0][0]
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D) (None, 56, 56, 64) 4160 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
batchnormalization_2 (BatchNormal(None, 56, 56, 64) 128 convolution2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D) (None, 56, 56, 192) 110784 batchnormalization_2[0][0]
____________________________________________________________________________________________________
batchnormalization_3 (BatchNormal(None, 56, 56, 192) 384 convolution2d_3[0][0]
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D) (None, 28, 28, 192) 0 batchnormalization_3[0][0]
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D) (None, 28, 28, 96) 18528 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_7 (Convolution2D) (None, 28, 28, 16) 3088 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
maxpooling2d_3 (MaxPooling2D) (None, 28, 28, 192) 0 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D) (None, 28, 28, 64) 12352 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_6 (Convolution2D) (None, 28, 28, 128) 110720 convolution2d_5[0][0]
____________________________________________________________________________________________________
convolution2d_8 (Convolution2D) (None, 28, 28, 32) 12832 convolution2d_7[0][0]
____________________________________________________________________________________________________
convolution2d_9 (Convolution2D) (None, 28, 28, 32) 6176 maxpooling2d_3[0][0]
____________________________________________________________________________________________________
merge_1 (Merge) (None, 28, 28, 256) 0 convolution2d_4[0][0]
convolution2d_6[0][0]
convolution2d_8[0][0]
convolution2d_9[0][0]
____________________________________________________________________________________________________
convolution2d_11 (Convolution2D) (None, 28, 28, 96) 24672 merge_1[0][0]
____________________________________________________________________________________________________
convolution2d_13 (Convolution2D) (None, 28, 28, 32) 8224 merge_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_4 (MaxPooling2D) (None, 28, 28, 256) 0 merge_1[0][0]
____________________________________________________________________________________________________
convolution2d_10 (Convolution2D) (None, 28, 28, 64) 16448 merge_1[0][0]
____________________________________________________________________________________________________
convolution2d_12 (Convolution2D) (None, 28, 28, 128) 110720 convolution2d_11[0][0]
____________________________________________________________________________________________________
convolution2d_14 (Convolution2D) (None, 28, 28, 64) 51264 convolution2d_13[0][0]
____________________________________________________________________________________________________
convolution2d_15 (Convolution2D) (None, 28, 28, 64) 16448 maxpooling2d_4[0][0]
____________________________________________________________________________________________________
merge_2 (Merge) (None, 28, 28, 320) 0 convolution2d_10[0][0]
convolution2d_12[0][0]
convolution2d_14[0][0]
convolution2d_15[0][0]
____________________________________________________________________________________________________
convolution2d_16 (Convolution2D) (None, 28, 28, 128) 41088 merge_2[0][0]
____________________________________________________________________________________________________
convolution2d_18 (Convolution2D) (None, 28, 28, 32) 10272 merge_2[0][0]
____________________________________________________________________________________________________
convolution2d_17 (Convolution2D) (None, 14, 14, 256) 295168 convolution2d_16[0][0]
____________________________________________________________________________________________________
convolution2d_19 (Convolution2D) (None, 14, 14, 64) 51264 convolution2d_18[0][0]
____________________________________________________________________________________________________
maxpooling2d_5 (MaxPooling2D) (None, 14, 14, 320) 0 merge_2[0][0]
____________________________________________________________________________________________________
merge_3 (Merge) (None, 14, 14, 640) 0 convolution2d_17[0][0]
convolution2d_19[0][0]
maxpooling2d_5[0][0]
____________________________________________________________________________________________________
convolution2d_21 (Convolution2D) (None, 14, 14, 96) 61536 merge_3[0][0]
____________________________________________________________________________________________________
convolution2d_23 (Convolution2D) (None, 14, 14, 32) 20512 merge_3[0][0]
____________________________________________________________________________________________________
maxpooling2d_6 (MaxPooling2D) (None, 14, 14, 640) 0 merge_3[0][0]
____________________________________________________________________________________________________
convolution2d_20 (Convolution2D) (None, 14, 14, 256) 164096 merge_3[0][0]
____________________________________________________________________________________________________
convolution2d_22 (Convolution2D) (None, 14, 14, 192) 166080 convolution2d_21[0][0]
____________________________________________________________________________________________________
convolution2d_24 (Convolution2D) (None, 14, 14, 64) 51264 convolution2d_23[0][0]
____________________________________________________________________________________________________
convolution2d_25 (Convolution2D) (None, 14, 14, 128) 82048 maxpooling2d_6[0][0]
____________________________________________________________________________________________________
merge_4 (Merge) (None, 14, 14, 640) 0 convolution2d_20[0][0]
convolution2d_22[0][0]
convolution2d_24[0][0]
convolution2d_25[0][0]
____________________________________________________________________________________________________
convolution2d_27 (Convolution2D) (None, 14, 14, 112) 71792 merge_4[0][0]
____________________________________________________________________________________________________
convolution2d_29 (Convolution2D) (None, 14, 14, 32) 20512 merge_4[0][0]
____________________________________________________________________________________________________
maxpooling2d_7 (MaxPooling2D) (None, 14, 14, 640) 0 merge_4[0][0]
____________________________________________________________________________________________________
convolution2d_26 (Convolution2D) (None, 14, 14, 224) 143584 merge_4[0][0]
____________________________________________________________________________________________________
convolution2d_28 (Convolution2D) (None, 14, 14, 224) 226016 convolution2d_27[0][0]
____________________________________________________________________________________________________
convolution2d_30 (Convolution2D) (None, 14, 14, 64) 51264 convolution2d_29[0][0]
____________________________________________________________________________________________________
convolution2d_31 (Convolution2D) (None, 14, 14, 128) 82048 maxpooling2d_7[0][0]
____________________________________________________________________________________________________
merge_5 (Merge) (None, 14, 14, 640) 0 convolution2d_26[0][0]
convolution2d_28[0][0]
convolution2d_30[0][0]
convolution2d_31[0][0]
____________________________________________________________________________________________________
convolution2d_33 (Convolution2D) (None, 14, 14, 128) 82048 merge_5[0][0]
____________________________________________________________________________________________________
convolution2d_35 (Convolution2D) (None, 14, 14, 32) 20512 merge_5[0][0]
____________________________________________________________________________________________________
maxpooling2d_8 (MaxPooling2D) (None, 14, 14, 640) 0 merge_5[0][0]
____________________________________________________________________________________________________
convolution2d_32 (Convolution2D) (None, 14, 14, 192) 123072 merge_5[0][0]
____________________________________________________________________________________________________
convolution2d_34 (Convolution2D) (None, 14, 14, 256) 295168 convolution2d_33[0][0]
____________________________________________________________________________________________________
convolution2d_36 (Convolution2D) (None, 14, 14, 64) 51264 convolution2d_35[0][0]
____________________________________________________________________________________________________
convolution2d_37 (Convolution2D) (None, 14, 14, 128) 82048 maxpooling2d_8[0][0]
____________________________________________________________________________________________________
merge_6 (Merge) (None, 14, 14, 640) 0 convolution2d_32[0][0]
convolution2d_34[0][0]
convolution2d_36[0][0]
convolution2d_37[0][0]
____________________________________________________________________________________________________
convolution2d_39 (Convolution2D) (None, 14, 14, 144) 92304 merge_6[0][0]
____________________________________________________________________________________________________
convolution2d_41 (Convolution2D) (None, 14, 14, 32) 20512 merge_6[0][0]
____________________________________________________________________________________________________
maxpooling2d_9 (MaxPooling2D) (None, 14, 14, 640) 0 merge_6[0][0]
____________________________________________________________________________________________________
convolution2d_38 (Convolution2D) (None, 14, 14, 160) 102560 merge_6[0][0]
____________________________________________________________________________________________________
convolution2d_40 (Convolution2D) (None, 14, 14, 288) 373536 convolution2d_39[0][0]
____________________________________________________________________________________________________
convolution2d_42 (Convolution2D) (None, 14, 14, 64) 51264 convolution2d_41[0][0]
____________________________________________________________________________________________________
convolution2d_43 (Convolution2D) (None, 14, 14, 128) 82048 maxpooling2d_9[0][0]
____________________________________________________________________________________________________
merge_7 (Merge) (None, 14, 14, 640) 0 convolution2d_38[0][0]
convolution2d_40[0][0]
convolution2d_42[0][0]
convolution2d_43[0][0]
____________________________________________________________________________________________________
convolution2d_44 (Convolution2D) (None, 14, 14, 160) 102560 merge_7[0][0]
____________________________________________________________________________________________________
convolution2d_46 (Convolution2D) (None, 14, 14, 64) 41024 merge_7[0][0]
____________________________________________________________________________________________________
convolution2d_45 (Convolution2D) (None, 7, 7, 256) 368896 convolution2d_44[0][0]
____________________________________________________________________________________________________
convolution2d_47 (Convolution2D) (None, 7, 7, 128) 204928 convolution2d_46[0][0]
____________________________________________________________________________________________________
maxpooling2d_10 (MaxPooling2D) (None, 7, 7, 640) 0 merge_7[0][0]
____________________________________________________________________________________________________
merge_8 (Merge) (None, 7, 7, 1024) 0 convolution2d_45[0][0]
convolution2d_47[0][0]
maxpooling2d_10[0][0]
____________________________________________________________________________________________________
convolution2d_49 (Convolution2D) (None, 7, 7, 192) 196800 merge_8[0][0]
____________________________________________________________________________________________________
convolution2d_51 (Convolution2D) (None, 7, 7, 48) 49200 merge_8[0][0]
____________________________________________________________________________________________________
maxpooling2d_11 (MaxPooling2D) (None, 7, 7, 1024) 0 merge_8[0][0]
____________________________________________________________________________________________________
convolution2d_48 (Convolution2D) (None, 7, 7, 384) 393600 merge_8[0][0]
____________________________________________________________________________________________________
convolution2d_50 (Convolution2D) (None, 7, 7, 384) 663936 convolution2d_49[0][0]
____________________________________________________________________________________________________
convolution2d_52 (Convolution2D) (None, 7, 7, 128) 153728 convolution2d_51[0][0]
____________________________________________________________________________________________________
convolution2d_53 (Convolution2D) (None, 7, 7, 128) 131200 maxpooling2d_11[0][0]
____________________________________________________________________________________________________
merge_9 (Merge) (None, 7, 7, 1024) 0 convolution2d_48[0][0]
convolution2d_50[0][0]
convolution2d_52[0][0]
convolution2d_53[0][0]
____________________________________________________________________________________________________
convolution2d_55 (Convolution2D) (None, 7, 7, 192) 196800 merge_9[0][0]
____________________________________________________________________________________________________
convolution2d_57 (Convolution2D) (None, 7, 7, 48) 49200 merge_9[0][0]
____________________________________________________________________________________________________
maxpooling2d_12 (MaxPooling2D) (None, 7, 7, 1024) 0 merge_9[0][0]
____________________________________________________________________________________________________
convolution2d_54 (Convolution2D) (None, 7, 7, 384) 393600 merge_9[0][0]
____________________________________________________________________________________________________
convolution2d_56 (Convolution2D) (None, 7, 7, 384) 663936 convolution2d_55[0][0]
____________________________________________________________________________________________________
convolution2d_58 (Convolution2D) (None, 7, 7, 128) 153728 convolution2d_57[0][0]
____________________________________________________________________________________________________
convolution2d_59 (Convolution2D) (None, 7, 7, 128) 131200 maxpooling2d_12[0][0]
____________________________________________________________________________________________________
merge_10 (Merge) (None, 7, 7, 1024) 0 convolution2d_54[0][0]
convolution2d_56[0][0]
convolution2d_58[0][0]
convolution2d_59[0][0]
____________________________________________________________________________________________________
averagepooling2d_1 (AveragePoolin(None, 1, 1, 1024) 0 merge_10[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 1024) 0 averagepooling2d_1[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 128) 131200 flatten_1[0][0]
____________________________________________________________________________________________________
lambda_1 (Lambda) (None, 128) 0 dense_1[0][0]
====================================================================================================
Total params: 7456944
____________________________________________________________________________________________________
None
Ответы
Ответ 1
Что могло произойти, кроме того, что скорость обучения была просто слишком высокой, так это то, что эффективно использовалась нестабильная стратегия отбора триплетов. Если, например, вы используете только "жесткие триплеты" (триплеты, где расстояние меньше, чем расстояние ap), вес вашей сети может свести все вложения в одну точку (делая потерю всегда равной margin (your _alpha
), потому что все расстояния вложения равны нулю).
Это можно исправить, используя также другие виды триплетов (например, "полутвердые триплеты", где ap меньше, чем an, но расстояние между ap и an все же меньше, чем запас). Так что, может быть, если вы всегда проверяли это... Это более подробно объясняется в этом сообщении в блоге: https://omoindrot.github.io/triplet-loss
Ответ 2
Вы ограничиваете свои вложения, чтобы быть "в d-мерной гиперсфере"? Попробуйте запустить tf.nn.l2_normalize
для ваших вложений сразу после их выхода из CNN.
Проблема может заключаться в том, что вложения являются своего рода умными алеками. Один простой способ уменьшить потери - просто установить все на ноль. l2_normalize
заставляет их быть на единицу длины.
Похоже, вы захотите добавить нормализацию сразу после последнего среднего пула.
Ответ 3
Я столкнулся с той же проблемой, и я провел некоторую исследовательскую работу. Я думаю, это потому, что потеря триплета требует нескольких входных сигналов, что может привести к тому, что сеть будет генерировать такие результаты. Я еще не исправил проблему, но вы можете проверить страницу проблем keras для более подробной информации https://github.com/keras-team/keras/issues/9498.
На странице проблемы я реализовал поддельный набор данных и фальшивую потерю триплета, чтобы решить проблему, после того как я изменил структуру ввода в сети, потеря становится нормальной.
Ответ 4
функция потерь в тензорном потоке требует списка меток, то есть списка целых чисел. Я думаю, что вы передаете 2D-матрицу, то есть одну горячую кодировку.
Попробуй это
import keras.backend as K
from tf.contrib.losses.metric_learning import triplet_semihard_loss
def loss(y_true, y_pred):
y_true = K.argmax(y_true, axis = -1)
return triplet_semihard_loss(labels=y_true, embeddings=y_pred, margin=1.)