
Issues saving and restoring a TensorFlow model (LSTM)

  •  0
  • madsthaks  · Tech Community  · 7 years ago

    I am trying to save and restore my trained model, using the tensorflow website as a resource. Here is how I build the graph:

    graph = tf.Graph()
    
    with graph.as_default():
        global_step = tf.Variable(0)
    
        data = tf.placeholder(tf.float32, [batch_size, len_section, char_size])
        labels = tf.placeholder(tf.float32, [batch_size, char_size])
    
        .....
    
        #Reset at the beginning of each test
        reset_test_state = tf.group(test_output.assign(tf.zeros([1, hidden_nodes])), 
                                    test_state.assign(tf.zeros([1, hidden_nodes])))
    
        #LSTM
        test_output, test_state = lstm(test_data, test_output, test_state)
        test_prediction = tf.nn.softmax(tf.matmul(test_output, w) + b)
    
        saver = tf.train.Saver()
    

    Here I train my model and save a checkpoint, breaking out of the loop after 30 iterations:

    with tf.Session(graph = graph) as sess:
        tf.global_variables_initializer().run()
        offset = 0
    
        for step in range(10000):
    
            offset = offset % len(X)
    
            if offset <= (len(X) - batch_size):
    
                batch_data = X[offset: offset + batch_size]
                batch_labels = y[offset:offset+batch_size]
                offset += batch_size
    
            else: 
                to_add = batch_size - (len(X) - offset)
                batch_data = np.concatenate((X[offset: len(X)], X[0: to_add]))
                batch_labels = np.concatenate((y[offset: len(X)], y[0: to_add]))
                offset = to_add
    
            _, training_loss = sess.run([optimizer, loss], feed_dict = {data : batch_data, labels : batch_labels})
    
            if step % 10 == 0:
                print('training loss at step %d: %.2f (%s)' % (step, training_loss, datetime.datetime.now()))
    
            if step % save_every == 0:
                saver.save(sess, checkpoint_directory + '/model.ckpt', global_step=step)
    
            if step == 30:
                break
    

    I looked in the directory and the following files were created:

    (screenshot of the files in the checkpoint directory)

    Here is where I restore the trained model and test it:

    with tf.Session(graph=graph) as sess:
        #standard init step
        offset = 0
        saver = tf.train.Saver()
        saver.restore(sess, "/ckpt/model-150.meta")
        tf.global_variables_initializer().run()
    
        test_start = "I plan to make this world a better place "
        test_generated = test_start
    
    ....
    

    When I run this, I get the following error:

    DataLossError (see above for traceback): Unable to open table file /ckpt/model.ckpt-30.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
    

    I'm not really sure what I'm doing wrong. The tutorial seems straightforward, but I'm clearly missing something. Any feedback would be greatly appreciated.

    1 Answer  |  7 years ago
        1
  •  1
  •   Alexandre Passos    7 years ago

    First, note that if you initialize all the variables after restoring from the checkpoint, you will get their random initial values rather than the trained values.
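    For example, a minimal sketch of the restore step, reusing the names from the graph above (some_test_batch is a hypothetical array shaped like test_data):

    with tf.Session(graph=graph) as sess:
        saver = tf.train.Saver()
        # Restore the trained weights. Do NOT run
        # tf.global_variables_initializer() afterwards, or the restored
        # values are overwritten with fresh random initial values.
        saver.restore(sess, checkpoint_directory + '/model.ckpt-30')

        sess.run(reset_test_state)
        prediction = sess.run(test_prediction,
                              feed_dict={test_data: some_test_batch})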

    Second, if you use tf.estimator.Estimator, saving and restoring checkpoints is handled for you automatically.
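    As a rough illustration only (the model_fn body, train_input_fn, and eval_input_fn below are placeholders, not the LSTM from the question), an Estimator writes checkpoints to model_dir during train() and restores the latest one for evaluate() or predict():

    def model_fn(features, labels, mode):
        # Placeholder model: a single dense layer standing in for the LSTM.
        logits = tf.layers.dense(features['x'], char_size)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=checkpoint_directory)
    estimator.train(input_fn=train_input_fn, steps=30)  # checkpoints saved to model_dir
    estimator.evaluate(input_fn=eval_input_fn)          # latest checkpoint restored automatically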

    Third, I don't understand how you are passing model-150.meta to restore but seeing an error about model-30.meta. In any case, you should pass model-30 (without the .meta suffix) to restore.
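    Concretely, a sketch of the corrected restore call using the directory from the question; tf.train.latest_checkpoint resolves the most recent checkpoint prefix from the 'checkpoint' file that the Saver writes:

    # Resolves to something like checkpoint_directory + '/model.ckpt-30'.
    ckpt_path = tf.train.latest_checkpoint(checkpoint_directory)

    with tf.Session(graph=graph) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_path)  # pass the checkpoint prefix, not the .meta file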