Ответ 1
Здесь мое решение, использующее контрольные точки V2, введенные в TF 0.12.
Нет необходимости преобразовывать все переменные в константы или заморозить график.
Просто для ясности контрольная точка V2 выглядит так в моем каталоге models
:
checkpoint # some information on the name of the files in the checkpoint
my-model.data-00000-of-00001 # the saved weights
my-model.index # probably definition of data layout in the previous file
my-model.meta # protobuf of the graph (nodes and topology info)
Часть (сохранение) Python
with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
Если вы создаете Saver
с помощью tf.trainable_variables()
, вы можете сэкономить себе головную боль и пространство для хранения. Но, возможно, некоторым более сложным моделям нужны все данные для сохранения, а затем удалите этот аргумент в Saver
, просто убедитесь, что вы создаете Saver
после, ваш график создан. Также очень важно дать всем переменным/слоям уникальные имена, иначе вы можете запустить различные проблемы.
Часть (вывод) Python
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)
Часть С++ (вывод)
Обратите внимание, что checkpointPath
не путь к любому из существующих файлов, просто их общий префикс. Если вы по ошибке поместите туда путь к файлу .index
, TF не скажет вам, что это было неправильно, но он умрет во время вывода из-за неинициализированных переменных.
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);