Following a tutorial, I built this Keras model:
import tensorflow as tf
from tensorflow import keras

# Small CNN for 28x28 grayscale images; the input and output layers are
# given explicit names ("input_node", "output_node").
inpt = keras.layers.Input(shape=(28, 28, 1), name="input_node")
x = keras.layers.Convolution2D(16, 2, padding='same', activation='relu')(inpt)
x = keras.layers.MaxPool2D(pool_size=2)(x)
x = keras.layers.Convolution2D(32, 2, padding='same', activation='relu')(x)
x = keras.layers.MaxPool2D(pool_size=2)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(128, activation='relu')(x)
output = keras.layers.Dense(10, activation='softmax', name="output_node")(x)
model = keras.models.Model(inpt, output)
model.compile(optimizer=keras.optimizers.Adam(lr=0.0001),
              loss='categorical_crossentropy', metrics=['accuracy'])
Then I converted it to an Estimator with model_to_estimator:
estimator_model = tf.keras.estimator.model_to_estimator(keras_model = model, model_dir = './TF_MNIST')
This works, and I can train it with:
estimator_model.train(input_fn = input_function(X_train,y_train,True))
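Rather than hard-coding checkpoint_state_name, the newest checkpoint can be looked up from the checkpoint state file that the Estimator writes into model_dir (TensorFlow's own tf.train.latest_checkpoint does this). The following is a minimal pure-Python sketch, assuming the usual text format of that file (a line such as model_checkpoint_path: "model.ckpt-21001"); both helper names are hypothetical.

```python
import os
import re


def parse_model_checkpoint_path(state_text, model_dir):
    """Extract the newest checkpoint prefix from the text of a 'checkpoint' state file."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', state_text, re.MULTILINE)
    if match is None:
        return None
    path = match.group(1)
    # Entries are usually relative to model_dir; absolute paths are kept as-is.
    return path if os.path.isabs(path) else os.path.join(model_dir, path)


def latest_checkpoint_prefix(model_dir):
    """Hypothetical pure-Python stand-in for tf.train.latest_checkpoint(model_dir)."""
    state_path = os.path.join(model_dir, "checkpoint")
    if not os.path.exists(state_path):
        return None
    with open(state_path) as f:
        return parse_model_checkpoint_path(f.read(), model_dir)
```

For example, latest_checkpoint_prefix('./TF_MNIST') would return a prefix such as './TF_MNIST/model.ckpt-21001', with no .index suffix.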
However, I want to freeze the graph with freeze_graph:
import os
from tensorflow.python.tools import freeze_graph

checkpoint_state_name = "model.ckpt-21001.index"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
input_graph_path = os.path.join('./TF_MNIST', input_graph_name)
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = os.path.join('./TF_MNIST', checkpoint_state_name)
output_node_names = "output_node"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join('./TF_MNIST', output_graph_name)
clear_devices = False
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, input_checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, initializer_nodes = "input_node")
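output_node_names has to name operations in the graph, and a Keras layer name is not necessarily an op name: a Dense layer called output_node typically creates ops such as output_node/MatMul, output_node/BiasAdd and output_node/Softmax, and the Estimator's copy of the model may use slightly different names. One way to see what is actually present is to list the node names in graph.pbtxt. The sketch below is a rough regex scan, not a full protobuf parser; it assumes the two-space indentation that TensorFlow's text-format writer uses for NodeDef fields, and the function names are hypothetical.

```python
import re


def node_names_from_pbtxt(pbtxt_text):
    """List node names in a text-format GraphDef such as graph.pbtxt.

    Each NodeDef appears as `node { ... }` with its `name:` field indented by
    exactly two spaces; fields with deeper indentation are ignored.
    """
    return re.findall(r'^ {2}name: "([^"]+)"', pbtxt_text, re.MULTILINE)


def find_nodes(pbtxt_text, prefix):
    """Return node names that start with `prefix` (for example a Keras layer name)."""
    return [name for name in node_names_from_pbtxt(pbtxt_text) if name.startswith(prefix)]
```

Usage would be along the lines of opening os.path.join('./TF_MNIST', 'graph.pbtxt'), reading it, and printing find_nodes(text, 'output_node'); an op from that list (for instance the Softmax op, if present) is a candidate for output_node_names.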
I chose the name output_graph.pb for the generated frozen graph file.
I get the following error:
ValueError Traceback (most recent call last)
<ipython-input-69-215edbaaf017> in <module>()
3 output_node_names, restore_op_name,
4 filename_tensor_name, output_graph_path,
----> 5 clear_devices, initializer_nodes = "input_node")
ValueError: No variables to save
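The tutorial's example does not pass an initializer_nodes argument, so I assumed it is the name of the input node. Also, when I use a checkpoint file other than the .index file, I get a Data loss error.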
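TensorFlow identifies a checkpoint by its path prefix (here model.ckpt-21001), from which files such as model.ckpt-21001.index and model.ckpt-21001.data-00000-of-00001 are derived. freeze_graph's input_checkpoint is normally given that prefix rather than one of the files, and pointing a restore at a data shard is a commonly reported source of Data loss errors. A minimal sketch of recovering the prefix from any of the file names, assuming the V2 checkpoint naming scheme (checkpoint_prefix is a hypothetical helper):

```python
import re

# File suffixes that TensorFlow's V2 checkpoint format adds to a prefix
# (.meta is the MetaGraph file written alongside by the Saver).
_CHECKPOINT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a checkpoint file suffix: 'model.ckpt-21001.index' -> 'model.ckpt-21001'."""
    return _CHECKPOINT_SUFFIX.sub("", path)
```

With the variables above, that would be input_checkpoint_path = checkpoint_prefix(os.path.join('./TF_MNIST', checkpoint_state_name)). Separately, freeze_graph's command-line help describes initializer_nodes as a comma-separated list of initializer nodes to run before freezing, with an empty string as the default.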
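Questions:
- How can I fix this error?
- The .index checkpoint file (see the Data loss issue above).
- The tutorial has an input_graph.pb, while my model directory has a .pbtxt file.
- tf.Session()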
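On the .pb versus .pbtxt point: freeze_graph can read the input graph either as a binary GraphDef or as a text-format one, and the input_binary argument says which. By convention .pb files are binary and .pbtxt files are text, so the tutorial's input_graph.pb would go with input_binary=True, while the graph.pbtxt written by the Estimator matches the input_binary = False used in the code above. A small sketch that derives the flag from the extension (the extension is only a convention, and input_binary_for is a hypothetical helper):

```python
import os


def input_binary_for(graph_path):
    """Guess freeze_graph's input_binary flag from a graph file's extension."""
    ext = os.path.splitext(graph_path)[1].lower()
    if ext == ".pb":
        return True   # serialized (binary) GraphDef
    if ext == ".pbtxt":
        return False  # text-format GraphDef
    raise ValueError("cannot infer input_binary from extension %r" % ext)
```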
Any help with these questions would be greatly appreciated.