Ответ 1
Вы можете использовать инструмент inspect_checkpoint.py
.
Я хочу видеть переменные, которые сохраняются в контрольной точке tensorflow вместе с их значениями. Как найти имена переменных, которые сохраняются в контрольной точке тензорного потока?
EDIT:
Я использовал tf.train.NewCheckpointReader
, который объясняется здесь. Но это не дано в документации тензорного потока. Есть ли другой способ?
`
import tensorflow as tf
v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0")
v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32,
name="v1")
init_all_op = tf.initialize_all_variables()
save = tf.train.Saver({"v0": v0, "v1": v1})
checkpoint_path = os.path.join(model_dir, "model.ckpt")
with tf.Session() as sess:
sess.run(init_all_op)
# Saves a checkpoint.
save.save(sess, checkpoint_path)
# Creates a reader.
reader = tf.train.NewCheckpointReader(checkpoint_path)
print('reder:\n', reader)
# Verifies that the tensors exist.
print('is exist v0?', reader.has_tensor("v0"))
print('is exist v1?', reader.has_tensor("v1"))
# Verifies that debug string contains the right strings.
debug_string = reader.debug_string()
print('\n All Variables: \n', debug_string)
# Verifies get_variable_to_shape_map() returns the correct information.
var_map = reader.get_variable_to_shape_map()
print('\n All Variables information :\n', var_map)
# Verifies get_tensor() returns the tensor value.
v0_tensor = reader.get_tensor("v0")
v1_tensor = reader.get_tensor("v1")
print('\n returns the v0 tensor value:\n', v0_tensor)
print('\n returns the v1 tensor value:\n', v1_tensor)
`
Вы можете использовать инструмент inspect_checkpoint.py
.
Использование примера:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
Обновление: all_tensors
аргумент был добавлен в print_tensors_in_checkpoint_file
, поскольку Tensorflow 0.12.0-rc0, чтобы вы может потребоваться добавить all_tensors=False
или all_tensors=True
, если это необходимо.
Альтернативный метод:
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
Надеюсь, что это поможет.
Добавление к предыдущему ответу:
Если модель сохраняется в формате V2
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
Вводимое имя контрольной точки должно быть только префиксом
print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True)
источник: by @LingjiaDeng в https://github.com/tensorflow/tensorflow/issues/7696