代码之家  ›  专栏  ›  技术社区  ›  Stefan Falk

从检查点加载的模型似乎未正确初始化

  •  0
  • Stefan Falk  · 技术社区  · 6 年前

    我正在从这样的检查点加载/还原模型:

    ckpt_path = tf.train.latest_checkpoint(self.checkpoints_dir)
    config = {'data_dir': os.path.dirname(self.vocab_filename), 'beam_size': 1, 'alpha': 0.6}
    
    # Restore
    graph = tf.Graph()
    session = tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True))
    
    with graph.as_default():
        _, self.input_node_name, self.output_node_name = load_translation_model(ckpt_path, config)
        meta_graph = tf.train.import_meta_graph(ckpt_path + '.meta')
        session.run(tf.global_variables_initializer())
        meta_graph.restore(session, ckpt_path)
    
    # Test restored model
    encoder_inputs = self.encode('how are you doing')
    
    sample = {'%s:0' % self.input_node_name: encoder_inputs}
    output_node = graph.get_tensor_by_name('%s:0' % self.output_node_name)
    
    result = session.run(output_node, feed_dict=sample)[0]
    decoded = self.decode(result)
    

    没有问题,除了 decoded 输出只是垃圾。对我来说,这看起来好像模型没有被正确地恢复,我运行的是随机初始化的变量-可能是这种情况吗?

    如果是的话,我在加载模型时哪里做错了?

    0 回复  |  直到 6 年前