代码之家  ›  专栏  ›  技术社区  ›  Karnivaurus

还原网络时,在还原的图形中找不到操作

  •  0
  • Karnivaurus  · 技术社区  · 6 年前

    使用tensorflow 1.9,我想在一个python文件中训练一个神经网络,然后使用另一个python文件恢复网络。我试图用一个简单的例子来实现这一点,但是当我试图加载“预测”操作时,我收到一个错误。具体来说,错误是: KeyError: "The name 'prediction' refers to an Operation not in the graph." .

    下面是我用来训练和保存网络的python文件。它生成一些示例数据,训练一个简单的神经网络,然后保存网络的每个时代。

    import numpy as np
    import tensorflow as tf
    
    input_data = np.zeros([100, 10])
    label_data = np.zeros([100, 1])
    for i in range(100):
        for j in range(10):
            input_data[i, j] = i * j / 1000
        label_data[i] = 2 * input_data[i, 0] + np.random.uniform(0.01)
    
    input_placeholder = tf.placeholder(tf.float32, shape=[None, 10], name='input_placeholder')
    label_placeholder = tf.placeholder(tf.float32, shape=[None, 1], name='label_placeholder')
    
    x = tf.layers.dense(inputs=input_placeholder, units=10, activation=tf.nn.relu)
    x = tf.layers.dense(inputs=x, units=10, activation=tf.nn.relu)
    prediction = tf.layers.dense(inputs=x, units=1, name='prediction')
    
    loss_op = tf.reduce_mean(tf.square(prediction - label_placeholder))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_op)
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_num in range(100):
            _, loss = sess.run([train_op, loss_op], feed_dict={input_placeholder: input_data, label_placeholder: label_data})
            print('epoch ' + str(epoch_num) + ', loss = ' + str(loss))
            saver.save(sess, '../Models/model', global_step=epoch_num + 1)
    

    下面是我用来恢复网络的python文件。它加载输入和输出占位符,以及进行预测所需的操作。然而,即使我将一个操作命名为 prediction 在上面的培训代码中,下面的代码似乎无法在加载的图表中找到此操作。

    import tensorflow as tf
    import numpy as np
    
    input_data = np.zeros([100, 10])
    for i in range(100):
        for j in range(10):
            input_data[i, j] = i * j / 1000
    
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('../Models/model-99.meta')
        saver.restore(sess, '../Models/model-99')
        graph = tf.get_default_graph()
        input_placeholder = graph.get_tensor_by_name('input_placeholder:0')
        label_placeholder = graph.get_tensor_by_name('label_placeholder:0')
        prediction = graph.get_operation_by_name('prediction')
        pred = sess.run([prediction], feed_dict={input_placeholder: input_data})
    

    为什么此代码找不到此操作,我应该如何更正代码?

    1 回复  |  直到 6 年前
        1
  •  1
  •   DocDriven    6 年前

    prediction = graph.get_tensor_by_name('prediction/BiasAdd:0')
    

    prediction.name tf.get_tensor_by_name