Есть ли пример того, как создавать файлы protobuf, содержащие обучаемые графики Tensorflow
Я рассматриваю пример Google о том, как развернуть и использовать предварительно подготовленный график (модель) Tensorflow на Android, по адресу:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
В этом примере используется файл .pb по адресу: [это ссылка на файл, который автоматически загружается]
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
В этом примере показано, как загрузить файл .pb в сеанс Tensorflow и использовать его для выполнения классификации, но не упоминает (?), как создать такой .pb файл после обучения графа (например, в Python).
Есть ли примеры того, как это сделать?
Ответы
Ответ 1
EDIT: freeze_graph.py
script, который является частью репозитория TensorFlow, теперь служит инструмент, который генерирует буфер протокола, представляющий "замороженную" обучаемую модель, из существующего TensorFlow GraphDef
и сохраненной контрольной точки. Он использует те же шаги, что описаны ниже, но гораздо проще в использовании.
В настоящее время процесс не очень хорошо документирован (и может быть уточнен), но приблизительные шаги заключаются в следующем:
- Создайте и подготовьте свою модель как
tf.Graph
под названием g_1
.
- Получить окончательные значения каждой из переменных и сохранить их как массивы numpy (используя
Session.run()
).
- В новом
tf.Graph
, называемом g_2
, создайте тензоры tf.constant()
для каждой из переменных, используя значение соответствующего массива numpy, выбранного на шаге 2.
-
Используйте tf.import_graph_def()
для копирования узлов из g_1
в g_2
и используйте аргумент input_map
для замены каждая переменная в g_1
с соответствующими тензорами tf.constant()
, созданная на шаге 3. Вы также можете использовать input_map
для указания нового входного тензора (например, заменяя введите < с tf.placeholder()
). Используйте аргумент return_elements
, чтобы указать имя прогнозируемого выходного тензора.
-
Вызвать g_2.as_graph_def()
, чтобы получить представление буфера в протоколе графика.
( ПРИМЕЧАНИЕ: Сгенерированный граф будет иметь дополнительные узлы в графике для обучения. Хотя он не является частью общедоступного API, вы можете использовать внутренний graph_util.extract_sub_graph()
, чтобы удалить эти узлы из графика.)
Ответ 2
В качестве альтернативы моему предыдущему ответу, используя freeze_graph()
, который хорош только, если вы называете его script, есть очень приятная функция, которая сделает весь тяжелый подъем для вас и подходит для вызова из вашего нормальный код обучения модели.
convert_variables_to_constants()
выполняет две вещи:
- Он замораживает вес, заменяя переменные константами
- Он удаляет узлы, которые не связаны с предсказанием вперед
Предполагая, что sess
- ваш tf.Session()
и "output"
- это имя вашего прогноза node, следующий код сериализует ваш минимальный график как в текстовый, так и в двоичный protobuf.
from tensorflow.python.framework.graph_util import convert_variables_to_constants
minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False)
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)
Ответ 3
Я не мог понять, как реализовать метод, описанный mrry. Но вот как я это решил. Я не уверен, что это лучший способ решить проблему, но, по крайней мере, она решает ее.
Поскольку write_graph также может хранить значения констант, я добавил следующий код в python непосредственно перед написанием графика с помощью функции write_graph:
for v in tf.trainable_variables():
vc = tf.constant(v.eval())
tf.assign(v, vc, name="assign_variables")
Это создает константы, которые сохраняют значения переменных после обучения, а затем создают тензоры " assign_variables", чтобы назначить их переменным. Теперь, когда вы вызываете write_graph, он будет хранить значения переменных в файле в виде констант.
Единственной оставшейся частью является вызов этих тензоров " assign_variables" в коде c, чтобы убедиться, что ваши переменные назначены значениями констант, которые хранятся в файле. Вот один из способов сделать это:
Status status = NewSession(SessionOptions(), &session);
std::vector<tensorflow::Tensor> outputs;
char name[100];
for(int i = 0;status.ok(); i++) {
if (i==0)
sprintf(name, "assign_variables");
else
sprintf(name, "assign_variables_%d", i);
status = session->Run({}, {name}, {}, &outputs);
}
Ответ 4
Вот еще один ответ на @Mostafa. Несколько более простой способ запуска tf.assign
ops - сохранить их в tf.group
. Здесь мой код Python:
ops = []
for v in tf.trainable_variables():
vc = tf.constant(v.eval())
ops.append(tf.assign(v, vc));
tf.group(*ops, name="assign_trained_variables")
И в С++:
std::vector<tensorflow::Tensor> tmp;
status = session.Run({}, {}, { "assign_trained_variables" }, &tmp);
if (!status.ok()) {
// Handle error
}
Таким образом, у вас есть только один именованный оператор op для запуска на стороне С++, поэтому вам не нужно путаться с итерацией по узлам.
Ответ 5
Просто нашел этот пост, и это было очень полезно! Я также использую метод @Mostafa, хотя мой код на С++ немного отличается:
std::vector<string> names;
int node_count = graph.node_size();
cout << node_count << " nodes in graph" << endl;
// iterate all nodes
for(int i=0; i<node_count; i++) {
auto n = graph.node(i);
cout << i << ":" << n.name() << endl;
// if name contains "var_hack", add to vector
if(n.name().find("var_hack") != std::string::npos) {
names.push_back(n.name());
cout << "......bang" << endl;
}
}
session.Run({}, names, {}, &outputs);
NB Я использую "var_hack" как имя моей переменной в python
Ответ 6
Я нашел функцию freeze_graph()
в кодовой базе Tensorflow, которая может быть полезна при выполнении этого. Из того, что я понимаю, он меняет переменные с константами перед сериализацией GraphDef, поэтому, когда вы загружаете этот график из С++, у него нет никаких переменных, которые необходимо установить больше, и вы можете напрямую использовать их для предсказаний.
Существует также test для него и некоторое описание в Руководство по.
Это похоже на самый чистый вариант здесь.