Как создать оптимизатор в Tensorflow

Я хочу написать новый алгоритм оптимизации для моей сети на Tensorflow. Я надеюсь реализовать алгоритм оптимизации Levenberg Marquardt, который теперь исключается из TF API. Я нашел плохую документацию о том, как писать настраиваемый оптимизатор, поэтому я спрашиваю, может ли кто-нибудь дать мне какие-либо советы. Спасибо.

Ответы

Ответ 1

Простейшим примером оптимизатора является, вероятно, оптимизатор спуска градиента. Он показывает, как создать экземпляр базового класса оптимизатора. Документация базового класса оптимизатора объясняет, что делают методы.

Сторона-оптимизатор на основе python добавляет новые узлы в график, который вычисляет и применяет градиенты, возвращаемые обратно. Он поставляет параметры, которые передаются в операционные системы, и делает некоторые из высокоуровневого управления оптимизатором. Затем вам понадобится фактическая операция "Применить".

Ops имеют как питон, так и компонент С++. Написание учебного курса является тем же (но специализированным), что и общий процесс добавления Op to TensorFlow.

Для примера набора обучающих операций, которые вычисляют и применяют градиенты, см. python/training/training_ops.py - это клей Python для реальных тренировок. Обратите внимание, что здесь код в основном касается вывода формы - вычисление будет в С++.

Фактическая математика для применения градиентов обрабатывается Op (напомним, что в общем случае ops написаны на С++). В этом случае операции op градиентов применяются в core/kernels/training_ops.cc. Вы можете увидеть, например, реализацию ApplyGradientDescentOp, которая ссылается на функтор ApplyGradientDescent:

var.device(d) -= grad * lr();

Реализация самой Op следует за реализацией любого другого op, как описано в документах add-an-op.

Ответ 2

Перед запуском сеанса Tensorflow необходимо запустить Оптимизатор, как показано ниже:

# Gradient Descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

tf.train.GradientDescentOptimizer - это объект класса GradientDescentOptimizer, и, как следует из названия, он реализует алгоритм спуска градиента.

Метод minimize() вызывается с параметром "стоимость" в качестве параметра и состоит из двух методов compute_gradients(), а затем apply_gradients().

Для большинства (пользовательских) реализаторов оптимизатора необходимо адаптировать метод apply_gradients().

Этот метод основан на (новом) Оптимизаторе (классе), который мы создадим, для реализации следующих методов: _create_slots(), _prepare(), _apply_dense() и _apply_sparse().

  • _create_slots() и _prepare() создать и инициализировать дополнительные переменные, такие как импульс.

  • _apply_dense() и _apply_sparse() реализовать фактические операционные системы, которые обновляют переменные.

Ops обычно записываются на С++. Без необходимости изменять заголовок С++ самостоятельно, вы все равно можете вернуть оболочку python некоторых Ops с помощью этих методов. Это делается следующим образом:

def _create_slots(self, var_list):
   # Create slots for allocation and later management of additional 
   # variables associated with the variables to train.
   # for example: the first and second moments.
   '''
   for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)
   '''
def _apply_dense(self, grad, var):
   #define your favourite variable update
    # for example:
   '''
   # Here we apply gradient descents by substracting the variables 
   # with the gradient times the learning_rate (defined in __init__)
   var_update = state_ops.assign_sub(var, self.learning_rate * grad) 
   '''
   #The trick is now to pass the Ops in the control_flow_ops and 
   # eventually groups any particular computation of the slots your 
   # wish to keep track of:
   # for example:    
   '''
    m_t = ...m... #do something with m and grad
    v_t = ...v... # do something with v and grad
    '''
  return control_flow_ops.group(*[var_update, m_t, v_t])

Более подробное объяснение с примером см. в этом блоге https://www.bigdatarepublic.nl/custom-optimizer-in-tensorflow/