代码之家  ›  专栏  ›  技术社区  ›  singrium

在同一个python进程中运行多个tensorflow会话时出错

  •  0
  • singrium  · 技术社区  · 6 年前

    我有一个具有此层次结构的项目:

    project
    ├── libs
    │   ├── __init__.py
    │   ├── sub_lib1
    │   │   ├── file1.py
    │   │   └── __init__.py
    │   └── sub_lib2
    │       ├── file2.py
    │       └── __init__.py
    └── main.py
    

    main.py的内容:

    from libs.sub_lib1.file1 import func1
    from libs.sub_lib2.file2 import func2
    
    
    #some code
    
    func1(parameters)
    
    #some code
    
    func2(parameters)
    
    #some code
    

    文件内容1.py:

    #import some packages
    import tensorflow as tf
    
    def func1(parameters):
    
        #some code
    
        config = tf.ConfigProto()
        config.gpu_options.allow_growth=True
        tf.reset_default_graph()
        x = tf.placeholder(tf.float32,shape=[None,IMG_SIZE_ALEXNET,IMG_SIZE_ALEXNET,3])
        y_true = tf.placeholder(tf.float32,shape=[None,output_classes])
        with tf.Session(config=config) as session:
            saver.restore(session, "path to the model1")
            k = session.run([tf.nn.softmax(y_pred)], feed_dict={x:test_x , hold_prob1:1,hold_prob2:1})
        #some code
        return(the_results)
    

    文件内容2.py:

    #import some packages
    import tensorflow as tf
    
    def func2(parameters):
    
        #some code
    
        config = tf.ConfigProto()
        config.gpu_options.allow_growth=True
        sess = tf.Session(config=config)
        with gfile.GFile('path the model2', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
        sess.run(tf.global_variables_initializer())
        #Get the needed tensors 
        input_img = sess.graph.get_tensor_by_name('Placeholder:0')
        output_cls_prob = sess.graph.get_tensor_by_name('Reshape_2:0')
        output_box_pred = sess.graph.get_tensor_by_name('rpn_bbox_pred/Reshape_1:0')
        #some code to prepare and resize the image
        cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blobs['data']})
        #some code
        return(the_results)
    

    运行main.py时,出现以下错误:

    Traceback (most recent call last):
      File "main.py", line 46, in <module>
        func2(parameters)
      File "/home/hani/opti/libs/sub_lib2/file2.py", line 76, in func2
        cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blobs['data']})
      File "/home/hani/.virtualenvs/opti/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
        run_metadata_ptr)
      File "/home/hani/.virtualenvs/opti/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1128, in _run
        str(subfeed_t.get_shape())))
    ValueError: Cannot feed value of shape (1, 600, 863, 3) for Tensor 'Placeholder:0', which has shape '(?, 227, 227, 3)'
    

    经过一些调试,我在第二个模型中没有找到任何具有(?)的张量。,227,227,3)形状。相反,我发现张量x(由 x = tf.placeholder(tf.float32,shape=[None,IMG_SIZE_ALEXNET,IMG_SIZE_ALEXNET,3]) 在func1文件1)中有(?,227,227,3)形状。 我检查了输入的形状 input_img = sess.graph.get_tensor_by_name('Placeholder:0') 在来自文件2的func2中,我找到了它(?,227,227,3)当我运行main.py时。但是,当我运行file2.py时(通过运行 python file2.py ,我没有得到这个错误,我发现输入的形状是占位符形状:(???,3)。
    所以我假设这两个模型都有相同的张量名称( 占位符 )当我在main.py中同时导入file1和file2时,占位符的第一个形状(?,227,227,3)保留在GPU内存中。
    我试过 session.close() 在file1.py中,但它不起作用!
    在同一进程中使用多个TensorFlow会话,而不会混淆它们,是否有更合适的方法?或者简单地说,如何在同一个Python进程中启动另一个前正确关闭TensorFlow会话?

    1 回复  |  直到 6 年前
        1
  •  1
  •   singrium    6 年前

    在阅读了堆栈溢出中的一些相关文章之后,我发现了 this answer 从中我引用:

    在第二次构建过程中,您可能会因为尝试 创建具有相同名称的变量(在您的案例中会发生什么); 正在定稿的图表等。

    为了解决我的问题,我只需要添加 tf.reset_default_graph() 到main.py,以便重置图形及其参数。

    推荐文章