Замените мониторы проверки с помощью tf.train.SessionRunHook при использовании оценочных

Я запускаю DNNClassifier, для которого я контролирую точность во время обучения. monitors.ValidationMonitor от contrib/learn работает отлично, в моей реализации я его определяю:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

а затем используйте вызов из:

clf.fit(input_fn=lambda: input_fn(A, Cl2),
            steps=1000, monitors=[validation_monitor])

где:

clf = tensorflow.contrib.learn.DNNClassifier(...

Это прекрасно работает. Тем не менее, проверки достоверности выглядят устаревшими и аналогичную функциональность заменяют на tf.train.SessionRunHook.

Я новичок в TensorFlow, и мне не кажется тривиальным, как будет выглядеть такая замещающая реализация. Любое предложение высоко ценится. Опять же, мне нужно пройти проверку обучения после определенного количества шагов. Большое спасибо заранее.

Ответы

Ответ 1

Там есть недокументированная утилита monitors.replace_monitors_with_hooks(), которая преобразует мониторы в крючки. Метод принимает (i) список, который может содержать как мониторы, так и крючки, и (ii) Оценщик, для которого будут использоваться крючки, а затем возвращает список крючков, обертывая SessionRunHook вокруг каждого монитора.

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

clf = tf.estimator.Estimator(...)

list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)

На самом деле это не является истинным решением проблемы полной замены ValidationMonitor - мы просто обертываем его ненулевой функцией. Тем не менее, я могу сказать, что это до сих пор работало на меня, поскольку оно поддерживало всю необходимую мне функциональность из ValidationMonitor (т.е. Оценивая все n шагов, раннюю остановку с использованием метрики и т.д.).

Еще одна вещь - использовать этот крючок вам нужно будет обновить с tf.contrib.learn.Estimator (который принимает только мониторы) до более полноценного и официального tf.estimator.Estimator (который принимает только крючки). Итак, вы должны создать экземпляр своего классификатора как tf.estimator.DNNClassifier, а затем использовать его метод train() (это просто переименование fit()):

clf = tf.estimator.Estimator(...)

...

clf.train(
    input_fn=...
    ...
    hooks=hooks)