Ответ 1
Вам нужно реализовать себе контрастирующую потерю или потерю триплета, но как только вы знаете пары или триплеты, это довольно легко.
Контрастная потеря
Предположим, что у вас есть входные пары данных и их ярлык (положительный или отрицательный, т.е. тот же класс или другой класс). Например, у вас есть изображения в виде размера 28x28x1:
left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
label = tf.placeholder(tf.int32, [None, 1]). # 0 if same, 1 if different
margin = 0.2
left_output = model(left) # shape [None, 128]
right_output = model(right) # shape [None, 128]
d = tf.reduce_sum(tf.square(left_output - right_output), 1)
d_sqrt = tf.sqrt(d)
loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d
loss = 0.5 * tf.reduce_mean(loss)
Потеря триплета
То же, что и с контрастирующей потерей, но с тройками (якорь, положительный, отрицательный). Здесь вам не нужны ярлыки.
anchor_output = ... # shape [None, 128]
positive_output = ... # shape [None, 128]
negative_output = ... # shape [None, 128]
d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)
loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)
Настоящая проблема при использовании потери триплета или контрастных потерь в TensorFlow , как пробовать триплеты или пары. Я сосредоточусь на создании триплетов, потому что это сложнее, чем генерация пар.
Самый простой способ - генерировать их за пределами графика Tensorflow, т.е. на питоне и передавать их в сеть через заполнители. В основном вы выбираете изображения 3 за раз, причем первые два из одного и того же класса и третий из другого класса. Затем мы выполняем прямое соединение этих триплетов и вычисляем потерю триплета.
Проблема в том, что генерация триплетов сложна. Мы хотим, чтобы они были действительными триплетами, тройками с положительной потерей (в противном случае потеря равна 0, а сеть не учится).
Чтобы узнать, хорош ли триплет или нет, вам нужно вычислить его потерю, поэтому вы уже делаете одну прямую через сеть...
Ясно, что реализация триплетной потери в Tensorflow сложна, и есть способы сделать ее более эффективной, чем выборка на python, но для объяснения им потребуется целая запись в блоге!