Какова последовательность вызываемой функции-члена SessionRunHook?

После прочтения API DOC я также не могу понять использование SessionRunHook. Например, какова последовательность функций участника SessionRunHook? Это after_create_session → before_run → begin → after_run → end? И я не могу найти учебник с подробными примерами, есть ли более подробное объяснение?

Ответы

Ответ 1

Вы можете найти учебник здесь, немного длинный, но вы можете перейти к части построения сети. Или вы можете прочитать мое краткое изложение ниже, основываясь на моем опыте.

Во-первых, вместо обычного Session следует использовать MonitoredSession.

SessionRunHook расширяет session.run() вызовы для MonitoredSession.

Тогда некоторые общие классы SessionRunHook можно найти здесь. Простым является LoggingTensorHook, но вы можете добавить следующую строку после импорта для просмотра журналов во время работы:

tf.logging.set_verbosity(tf.logging.INFO)

Или у вас есть возможность реализовать свой собственный класс SessionRunHook. Простой из учебника cifar10

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

где loss определяется вне класса. Этот _LoggerHook использует print для печати информации, в то время как LoggingTensorHook использует tf.logging.INFO.

Наконец, для лучшего понимания, как это работает, порядок выполнения представлен псевдокодом с MonitoredSession здесь:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()

Надеюсь, это поможет.

Ответ 2

tf.SessionRunHook позволяет вам добавлять свой код в течение каждой команды запуска сеанса, выполняемой в коде. Чтобы понять это, я создал простой пример ниже:

  1. Мы хотим напечатать значения потерь после каждого обновления параметров.
  2. Для этого мы будем использовать SessionRunHook.

Создать график тензорного потока

import tensorflow as tf
import numpy as np

x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1], [1.]]
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)

Создание крюка

class _Hook(tf.train.SessionRunHook):
  def __init__(self, loss):
    self.loss = loss

  def begin(self):
    pass

  def before_run(self, run_context):
    return tf.train.SessionRunArgs(self.loss)  

  def after_run(self, run_context, run_values):
    loss_value = run_values.results
    print("loss value:", loss_value)

Создание контролируемого сеанса с помощью hook

sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])

поезд

for _ in range(10):
  x_ = np.random.random((10, 2))
  sess.run(optimizer, {x: x_})
# Output
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494