代码之家  ›  专栏  ›  技术社区  ›  pfc

如何计算从pb文件加载的tensorflow模型的flops

  •  4
  • pfc  · 技术社区  · 6 年前

    import tensorflow as tf
    import sys
    from tensorflow.python.platform import gfile
    
    from tensorflow.core.protobuf import saved_model_pb2
    from tensorflow.python.util import compat
    
    pb_file = 'themodel.pb'
    
    run_meta = tf.RunMetadata()
    with tf.Session() as sess:
        print("load graph")
        with gfile.FastGFile(pb_path,'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
            flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
                options=tf.profiler.ProfileOptionBuilder.float_operation())
            print("test flops:{:,}".format(flops.total_float_ops))
    

    打印信息很奇怪。我的模型有几十层,但它报告的打印信息只有18次失败。我很确定模型是正确加载的,因为如果我尝试按如下方式打印每个层的名称:

    print([n.name for n in tf.get_default_graph().as_graph_def().node])
    

    打印信息正好显示了正确的网络。

    我的代码怎么了?

    2 回复  |  直到 6 年前
        1
  •  1
  •   pfc    5 年前

    我想我找到了问题的原因和解决办法。下面的代码可以打印给定pb文件的触发器。

    import os
    import tensorflow as tf
    from tensorflow.core.framework import graph_pb2
    from tensorflow.python.framework import importer
    
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    
    pb_path = 'mymodel.pb'
    
    run_meta = tf.RunMetadata()
    with tf.Graph().as_default():
        output_graph_def = graph_pb2.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = importer.import_graph_def(output_graph_def, name="")
            print('model loaded!')
        all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
        # for k in all_keys:
        #   print(k)
    
        with tf.Session() as sess:
            flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
                options=tf.profiler.ProfileOptionBuilder.float_operation())
            print("test flops:{:,}".format(flops.total_float_ops))
    

    问题中打印的flops仅为18的原因是,在生成pb文件时,我将输入图像形状设置为 [None, None, 3] . 如果我改成 [500, 500, 3] ,则打印的触发器将是正确的。

        2
  •  0
  •   Allen Lavoie    6 年前

    在不知道输入和输出的情况下,不确定它将如何计算任何性能度量:也许它需要 CallableOptions trace_next_step and a Session 而不是手动计算。