Ранняя остановка с помощью tf.estimator, как?

Я использую tf.estimator в tf.estimator 1.4 и tf.estimator.train_and_evaluate отлично, но мне нужна ранняя остановка. Какой предпочтительный способ добавить это?

Я предполагаю, что для этого есть некоторый tf.train.SessionRunHook. Я видел, что был старый пакет с пакетом ValidationMonitor который, казалось, был на ранней стадии остановки, но, похоже, это не похоже на 1.4. Или в будущем предпочтительнее будет полагаться на tf.keras (с которым ранняя остановка действительно проста), а не tf.estimator/tf.layers/tf.data, возможно?

Ответы

Ответ 1

Хорошие новости! tf.estimator теперь есть ранняя остановка поддержки на хозяине, и похоже, что она будет в 1.10.

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

Ответ 2

Во-первых, вы должны назвать потерю, чтобы сделать ее доступной для раннего прерывания вызова. Если ваша переменная потерь называется "потеря" в оценке, линия

copyloss = tf.identity(loss, name="loss")

прямо под ним будет работать.

Затем создайте крючок с этим кодом.

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

это сравнивает экспоненциально сглаженную проверку потерь с ее наименьшим значением, и если она выше по допускам, она перестает тренироваться. Если он останавливается слишком рано, повышение допуска и сглаживания заставит его остановиться позже. Продолжайте сглаживание ниже одного, или оно никогда не прекратится.

Вы можете заменить логику в after_run чем-то другим, если хотите остановить на основе другого условия.

Теперь добавьте этот крючок в спецификацию оценки. Ваш код должен выглядеть примерно так:

eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

Важное примечание. Функция, run_context.request_stop() нарушена в вызове train_and_evaluate и не прекращает обучение. Итак, я поднял значение ошибки, чтобы остановить обучение. Таким образом, вы должны обернуть вызов train_and_evaluate в блок catch try следующим образом:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

если вы этого не сделаете, код будет сбой с ошибкой, когда обучение остановится.

Ответ 3

Да, есть tf.train.StopAtStepHook:

Этот запрос крюка останавливается после выполнения нескольких шагов или последнего шага. Можно указать только один из двух параметров.

Вы также можете расширить его и реализовать свою собственную стратегию остановки на основе результатов шага.

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()

Ответ 4

Другой вариант, который не использует перехватчики, заключается в создании tf.contrib.learn.Experiment (который, кажется, даже в contrib, также поддерживает новый tf.estimator.Estimator).

Затем тренируйтесь через (по-видимому, экспериментальный) метод continuous_train_and_eval с соответствующим образом настроенным continuous_eval_predicate_fn.

В соответствии с документом tensorflow continuous_eval_predicate_fn равен

Функция предиката, определяющая, следует ли продолжить eval после каждой итерации.

и вызвал с eval_results от последнего прогона оценки. Для ранней остановки используйте настраиваемую функцию, которая сохраняет как состояние текущий лучший результат и счетчик и возвращает False при достижении условия для ранней остановки.

Примечание добавлено: этот подход будет использовать устаревшие методы w/tensorflow 1.7 (все tf.contrib.learn устарели с этой версии и далее: https://www.tensorflow.org/api_docs/python/tf/contrib/learn)