Как восстановить модель Tensorflow из файла .pb в python?

У меня есть файл tenorflow.pb, который я хотел бы загрузить в DNN Python, восстановить график и получить прогнозы. Я делаю это, чтобы проверить, может ли созданный файл .pb сделать прогнозы похожими на обычную модель Saver.save().

Моя основная проблема в том, что я получаю совершенно разные значения прогнозов, когда я делаю их на Android с помощью вышеупомянутого файла .pb

Мой код создания файла .pb:

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())

Итак, у меня есть две основные проблемы:

  1. Как я могу загрузить вышеупомянутый файл .pb в модель Python Tensorflow?
  2. Почему я получаю совершенно разные значения прогноза в Python и Android?

Ответы

Ответ 1

Следующий код прочитает модель и выведет имена узлов на графике.

import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

Вы правильно замораживаете график, поэтому вы получаете разные результаты, в основном веса не сохраняются в вашей модели. Вы можете использовать freeze_graph.py (ссылка) для получения правильно сохраненного графика.