Ответ 1
Обратите внимание, что weighted_cross_entropy_with_logits
- взвешенный вариант sigmoid_cross_entropy_with_logits
. Сигмовидная кросс-энтропия обычно используется для двоичной классификации. Да, он может обрабатывать несколько ярлыков, но сигмовидная кросс-энтропия в основном делает (двоичное) решение для каждого из них - например, для сети распознавания лиц эти (не взаимоисключающие) метки могут быть: "Означает ли предмет очки?", "Является ли тема женщиной?" И т.д.
В двоичной классификации (ов) каждый выходной канал соответствует двоичному (мягкому) решению. Поэтому взвешивание должно происходить при вычислении потерь. Это то, что weighted_cross_entropy_with_logits
делает, взвешивая один член кросс-энтропии над другим.
Во взаимоисключающей многосегментной классификации мы используем softmax_cross_entropy_with_logits
, которая ведет себя по-разному: каждый выходной канал соответствует счету кандидата класса. Решение приходит после, путем сравнения соответствующих выходов каждого канала.
Взвешивание до окончательного решения - это просто вопрос изменения оценок перед их сравнением, как правило, путем умножения с весами. Например, для тройной задачи классификации
# your class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(onehot_labels, logits)
# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)
Вы также можете положиться на tf.losses.softmax_cross_entropy
для обработки последних трех шагов.
В вашем случае, когда вам нужно решить проблему дисбаланса данных, вес класса может действительно быть обратно пропорциональным их частоте в данных вашего поезда. Нормализация их так, чтобы они суммировались с одним или с числом классов, также имеет смысл.
Обратите внимание, что в приведенном выше случае мы оштрафовали убыток на основе истинной метки образцов. Мы также могли бы оштрафовать потери на основе оцененных меток, просто определяя
weights = class_weights
а остальная часть кода не должна изменяться благодаря магии вещания.
В общем случае вы хотели бы, чтобы весы зависели от той ошибки, которую вы совершаете. Другими словами, для каждой пары меток X
и Y
вы можете выбрать, как оштрафовать выбор метки X
, когда истинная метка Y
. В итоге вы получаете целую матрицу весов, которая приводит к тому, что weights
выше является полным тензором (num_samples, num_classes)
. Это немного зависит от того, что вы хотите, но может быть полезно знать, тем не менее, что только ваше определение тензора веса необходимо изменить в коде выше.