Внедрение Tensorflow слова2vec
Учебник Tensorflow здесь относится к их базовой реализации, которую вы можете найти на github здесь, где авторы Tensorflow реализуют обучение/оценку внедрения векторного вложения word2vec с помощью модели Skipgram.
Мой вопрос о фактической генерации (целевых, контекстных) пар в функции generate_batch()
.
В эта строка Авторы тензорного потока произвольно выбирают соседние целевые индексы из индекса слова "центр" в скользящем окне слов.
Однако они также сохраняют структуру данных targets_to_avoid
, к которой они сначала добавляют "центральное" контекстное слово (что, конечно, мы надеваем 't хочу попробовать), но ТАКЖЕ другие слова после их добавления.
Мои вопросы таковы:
- Почему выборка из этого скользящего окна вокруг слова, почему бы просто не использовать петлю и использовать их все, а не выборку? Кажется странным, что они будут беспокоиться о производительности/памяти в
word2vec_basic.py
(их "базовая" реализация).
- Каким бы ни был ответ на вопрос 1), почему они берут выборку и отслеживают, что они выбрали с помощью
targets_to_avoid
? Если бы они хотели по-настоящему случайным образом, они использовали бы выбор с заменой, и если бы они хотели, чтобы у них были все варианты, они должны были просто использовать цикл и получить их все в первую очередь!
- Работает ли этот встроенный tf.models.embedding.gen_word2vec? Если да, где я могу найти исходный код? (не удалось найти файл .py в репозитории Github)
Спасибо!
Ответы
Ответ 1
Я попробовал свой предложенный способ создания пакетов - с циклом и с использованием всего пропуска. Результаты:
1. Быстрая генерация пакетов
Для размера партии 128 и окна пропуска 5
- создание пакетов путем циклического перебора данных по одному занимает 0,73 с на 10000 партий
- создание пакетов с кодом учебника и
num_skips=2
занимает 3.59s за 10 000 партий
2. Более высокая опасность переоснащения
Сохраняя остальную часть кода учебника, я тренировал модель в обоих направлениях и регистрировал среднюю потерю каждые 2000 шагов:
![введите описание изображения здесь]()
Этот шаблон повторялся несколько раз. Он показывает, что использование 10 выборок на слово вместо 2 может привести к переобучению.
Вот код, который я использовал для создания пакетов. Он заменяет учебник generate_batch
.
data_index = 0
def generate_batch(batch_size, skip_window):
global data_index
batch = np.ndarray(shape=(batch_size), dtype=np.int32) # Row
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) # Column
# For each word in the data, add the context to the batch and the word to the labels
batch_index = 0
while batch_index < batch_size:
context = data[get_context_indices(data_index, skip_window)]
# Add the context to the remaining batch space
remaining_space = min(batch_size - batch_index, len(context))
batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
labels[batch_index:batch_index + remaining_space] = data[data_index]
# Update the data_index and the batch_index
batch_index += remaining_space
data_index = (data_index + 1) % len(data)
return batch, labels
Изменить: get_context_indices
- это простая функция, которая возвращает срез индекса в skip_window вокруг data_index. Дополнительную информацию см. В документации slice().
Ответ 2
Существует параметр с именем num_skips
, который обозначает количество (входных, выходных) пар, генерируемых из одного окна: [skip_window target skip_window]. Таким образом, num_skips
ограничивает количество контекстных слов, которые мы будем использовать в качестве выходных слов. И поэтому функция generate_batch assert num_skips <= 2*skip_window
. Код просто случайно выбирает num_skip
контекстные слова для построения пар тренировок с целью.
Но я не знаю, как num_skips
влияет на производительность.