Где next_batch в учебнике TensorFlow batch_xs, batch_ys = mnist.train.next_batch (100)?
Я тестирую учебник TensorFlow и не понимаю, откуда идет next_batch в этой строке?
batch_xs, batch_ys = mnist.train.next_batch(100)
Я посмотрел
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
И не видел next_batch там.
Теперь, когда вы тестируете next_batch в моем собственном коде, я получаю
AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'
Итак, я хотел бы понять, откуда происходит next_batch?
Ответы
Ответ 1
next_batch
- это метод класса DataSet
(см. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py для получения дополнительной информации о том, что в классе).
Когда вы загружаете данные mnist и назначаете его переменной mnist
с помощью:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Посмотрите на класс mnist.train
. Вы можете увидеть его, набрав:
print mnist.train.__class__
Вы увидите следующее:
<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>
Поскольку mnist.train
является экземпляром класса DataSet
, вы можете использовать функцию класса next_batch
. Для получения дополнительной информации о классах ознакомьтесь с документацией.
Ответ 2
Просматривая репозиторий tensorflow, он, кажется, возникает здесь:
https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905
Однако, если вы хотите реализовать его в своем собственном коде (для своего собственного набора данных), было бы гораздо проще записать его непосредственно в объекте набора данных, как и я. Насколько я понимаю, это способ перетасовать весь набор данных и вернуть $mini_batch_size количество выборок из перетасованного набора данных.
Здесь некоторый псевдокод:
shuffle data.x and data.y while retaining relation
return [data.x[:mb_n], data.y[:mb_n]]
Ответ 3
Вы можете просто использовать справочную функцию:
help(tf.contrib.learn.datasets.mnist.DataSet.next_batch)
и получить документ функции next_batch