Ответ 1
В примерах и учебниках Tensorflow заметным шаблоном для структурирования кода модели является разбиение модели на три функции:
-
inference(inputs, ...)
, который строит модель -
loss(logits, ...)
, который добавляет потери поверх логитов -
train(loss, ...)
, который добавляет учебные операции
При создании модели для обучения ваш код будет выглядеть примерно так:
inputs = tf.placeholder(...)
logits = mymodel.inference(inputs, ...)
loss = mymodel.loss(logits, ...)
train = mymodel.train(loss, ...)
Этот шаблон используется, например, для учебника CIFAR-10 (code, tutorial).
Одна вещь, которую можно было бы споткнуться, - это тот факт, что вы не можете делиться (Python) переменными между функциями inference
и loss
. Это не большая проблема, поскольку, поскольку Tensorflow предоставляет коллекции графиков именно для этого варианта использования, делая для более чистого дизайна (поскольку он делает вас группируйте свои вещи логически). Одним из основных прецедентов для этого является регуляризация:
Если вы используете модуль layers
(например, tf.layers.conv2d
), у вас уже есть то, что вам нужно, так как будут добавлены все санкции регуляризации (source) в коллекцию tf.GraphKeys.REGULARIZATION_LOSSES
по умолчанию. Например, когда вы это делаете:
conv1 = tf.layers.conv2d(
inputs,
filters=96,
kernel_size=11,
strides=4,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.contrib.layers.l2_regularizer(),
name='conv1')
Ваша потеря может выглядеть так:
def loss(logits, labels):
softmax_loss = tf.losses.softmax_cross_entropy(
onehot_labels=labels,
logits=logits)
regularization_loss = tf.add_n(tf.get_collection(
tf.GraphKeys.REGULARIZATION_LOSSES)))
return tf.add(softmax_loss, regularization_loss)
Если вы не используете модуль слоев, вам придется заполнить коллекцию вручную (так же, как в связанном фрагменте источника). В основном вы хотите добавить штрафы в коллекцию, используя tf.add_to_collection
:
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg_penalty)
С помощью этого вы можете рассчитать потерю, включая штрафы за регуляризацию, как указано выше.