Как разбить наборы данных Tensorflow?
У меня есть набор данных tensorflow, основанный на одном файле.tfrecord. Как разбить набор данных на тестовые и обучающие наборы данных? Например, 70% поезда и 30% тест?
Редактировать:
Моя версия Tensorflow: 1.8 Я проверил, нет функции split_v, как указано в возможном дубликате. Также я работаю с файлом tfrecord.
Ответы
Ответ 1
Вы можете использовать Dataset.take()
и Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
Для большей общности, я привел пример использования 70/15/15 split/val/test split, но если вам не нужен тест или набор значений, просто проигнорируйте последние 2 строки.
Возьмите:
Создает набор данных с максимальным количеством элементов из этого набора данных.
Пропустить:
Создает набор данных, который пропускает количество элементов из этого набора данных.
Вы также можете заглянуть в Dataset.shard()
:
Создает набор данных, который включает только 1/num_shards этого набора данных.
Ответ 2
Этот вопрос похож на этот и этот, и, боюсь, у нас пока нет удовлетворительного ответа.
Использование take()
и skip()
требует знания размера набора данных. Что если я этого не знаю или не хочу это выяснять?
Использование shard()
дает только 1 / num_shards
набора данных. Что, если я хочу отдохнуть?
Ниже я пытаюсь представить лучшее решение, протестированное только на TensorFlow 2. Предполагая, что у вас уже есть перемешанный набор данных, вы можете использовать filter()
, чтобы разделить его на две части:
import tensorflow as tf
all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
.shuffle(10, reshuffle_each_iteration=False)
test_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 == 0) \
.map(lambda x,y: y)
train_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 != 0) \
.map(lambda x,y: y)
for i in test_dataset:
print(i)
print()
for i in train_dataset:
print(i)
Параметр reshuffle_each_iteration=False
важен. Это гарантирует, что исходный набор данных перемешивается один раз и не более. В противном случае два результирующих набора могут иметь некоторые совпадения.
Используйте enumerate()
, чтобы добавить индекс.
Используйте filter(lambda x,y: x % 4 == 0)
, чтобы взять 1 образец из 4. Аналогично, x % 4 != 0
заберет 3 из 4.
Используйте map(lambda x,y: y)
, чтобы убрать указатель и восстановить исходный образец.
В этом примере достигается разделение 75/25.
x % 5 == 0
и x % 5 != 0
дают разделение 80/20.
Если вы действительно хотите разделить 70/30, x % 10 < 3
и x % 10 >= 3
должны это сделать.
UPDATE:
Начиная с TensorFlow 2.0.0, приведенный выше код может вызывать некоторые предупреждения из-за ограничений автографа. Чтобы устранить эти предупреждения, объявите все лямбда-функции отдельно:
def is_test(x, y):
return x % 4 == 0
def is_train(x, y):
return not is_test(x, y)
recover = lambda x,y: y
test_dataset = all.enumerate() \
.filter(is_test) \
.map(recover)
train_dataset = all.enumerate() \
.filter(is_train) \
.map(recover)
Это не дает предупреждения на моей машине. И превращение is_train()
в not is_test()
определенно хорошая практика.