TensorFlow сохранение/загрузка графика из файла
Из того, что я собрал до сих пор, существует несколько разных способов сброса графика TensorFlow в файл, а затем загрузка его в другую программу, но я не смог найти четкие примеры/информацию о том, как они работают, Я уже знаю это:
- Сохраните переменные модели в файл контрольной точки (.ckpt) с помощью
tf.train.Saver()
и восстановите их позже (source)
- Сохраните модель в файле .pb и загрузите ее обратно с помощью
tf.train.write_graph()
и tf.import_graph_def()
(source)
- Загрузите модель из файла .pb, переустановите ее и выгрузите в новый .pb файл с помощью Bazel (source)
- Зафиксируйте график, чтобы сохранить график и вес вместе (источник)
- Используйте
as_graph_def()
для сохранения модели и для весов/переменных, сопоставьте их с константами (источник)
Однако я не смог прояснить несколько вопросов относительно этих разных методов:
- Что касается файлов контрольных точек, они сохраняют только подготовленные веса модели? Могут ли файлы контрольных точек загружаться в новую программу и использоваться для запуска модели, или они просто служат в качестве способов сохранения весов в модели в определенное время/этап?
- Что касается
tf.train.write_graph()
, также сохраняются ли весы/переменные?
- Что касается Bazel, может ли он только сохранить/загрузить из .pb файлов для переподготовки? Есть ли простая команда Bazel, чтобы сбрасывать граф в .pb?
- Что касается замораживания, может ли загруженный замороженный граф использовать
tf.import_graph_def()
?
- Демонстрация Android для загрузки TensorFlow в модели Google Inception из файла .pb. Если бы я хотел подставить свой собственный .pb файл, как бы я это сделал? Должен ли я изменить любой собственный код/методы?
- В общем, какая именно разница между всеми этими методами? Или более широко, в чем разница между
as_graph_def()
/. Ckpt/.pb?
Короче говоря, то, что я ищу, - это метод, который позволяет сохранить как график (как в, различные операции и т.д.), так и его вес/переменные в файл, который затем можно использовать для загрузки графика и весов в другую программу для использования (не обязательно продолжение/переподготовка).
Документация по этой теме не очень проста, поэтому любые ответы/информация были бы оценены.
Ответы
Ответ 1
Есть много способов подойти к проблеме сохранения модели в TensorFlow, что может сделать ее несколько запутанной. Принимая каждый из ваших вопросов:
-
Файлы контрольных точек (создаваемые, например, путем вызова saver.save()
объекта tf.train.Saver
) содержат только веса и любые другие переменные, определенные в одной и той же программе. Чтобы использовать их в другой программе, вы должны повторно создать связанную структуру графа (например, запустив код для его сборки снова или вызвав tf.import_graph_def()
), который сообщает TensorFlow, что делать с этими весами. Обратите внимание, что вызов saver.save()
также создает файл, содержащий MetaGraphDef
, который содержит график и сведения о том, как связать веса с контрольной точки с этим графом. Дополнительную информацию см. В руководстве.
-
tf.train.write_graph()
записывает только структуру графа; а не веса.
-
Bazel не связан с чтением или написанием графиков TensorFlow. (Возможно, я неправильно понимаю ваш вопрос: не стесняйтесь прояснить это в комментарии.)
-
Замороженный график можно загрузить с помощью tf.import_graph_def()
. В этом случае весы (обычно) встроены в график, поэтому вам не нужно загружать отдельную контрольную точку.
-
Основное изменение заключалось бы в обновлении имен тензора (ов), которые подаются в модель, и имен тензора (ов), которые извлекаются из модели. В демоверсии TensorFlow Android это будет соответствовать inputName
и outputName
, которые передаются TensorFlowClassifier.initializeTensorFlow()
.
-
GraphDef
- это структура программы, которая обычно не изменяется в процессе обучения. Контрольная точка представляет собой моментальный снимок состояния процесса обучения, который обычно изменяется на каждом этапе учебного процесса. В результате TensorFlow использует разные форматы хранения данных этих типов, а низкоуровневый API предоставляет различные способы их сохранения и загрузки. Библиотеки более высокого уровня, такие как MetaGraphDef
библиотеки, Keras и skflow на основе этих механизмов, чтобы обеспечить более удобные способы сохранения и восстановления целой модели.
Ответ 2
Вы можете попробовать следующий код:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)