代码之家  ›  专栏  ›  技术社区  ›  Yanjun

如何计算由加载的图形定义的tensorflow模型中可训练参数的总数。pb文件?

  •  2
  • Yanjun  · 技术社区  · 6 年前

    我想计算tensorflow模型中的参数。这与现有问题类似,如下所示。

    How to count total number of trainable parameters in a tensorflow model?

    但如果模型是使用加载自的图形定义的。pb文件,所有建议的答案都不起作用。基本上,我用以下函数加载了图形。

    def load_graph(model_file):
    
      graph = tf.Graph()
      graph_def = tf.GraphDef()
    
      with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    
      with graph.as_default():
        tf.import_graph_def(graph_def)
    
      return graph
    

    一个示例是加载冻结的\u图。tensorflow-for-poets-2中用于再培训的pb文件。

    https://github.com/googlecodelabs/tensorflow-for-poets-2

    1 回复  |  直到 6 年前
        1
  •  1
  •   Y. Luo    6 年前

    据我所知 GraphDef 没有足够的信息来描述 Variables .如上所述 here ,您将需要 MetaGraph ,其中包含 GraphDef公司 CollectionDef 这是一张可以描述 变量 因此,以下代码应为我们提供正确的可训练变量计数。

    导出元图形:

    import tensorflow as tf
    
    a = tf.get_variable('a', shape=[1])
    b = tf.get_variable('b', shape=[1], trainable=False)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver([a])
    
    with tf.Session() as sess:
        sess.run(init)
        saver.save(sess, r'.\test')
    

    导入元图并计算可训练参数的总数。

    import tensorflow as tf
    
    saver = tf.train.import_meta_graph('test.meta')
    
    with tf.Session() as sess:
        saver.restore(sess, 'test')
    
    total_parameters = 0
    for variable in tf.trainable_variables():
        total_parameters += 1
    print(total_parameters)