代码之家  ›  专栏  ›  技术社区  ›  2adnielsenx xx

vgg16 keras保存模型以导出c++

  •  1
  • 2adnielsenx xx  · 技术社区  · 6 年前

    我最近正在研究vgg16以重用其模型。

    在python中:(Keras)

    model = applications.VGG16(include_top=False, weights='imagenet')
    

    一切都很好。

    我需要用compile和fit导出这个模型,以导出c++json文件。 如何正确导出h5文件以用于导出模型?

    1 回复  |  直到 6 年前
        1
  •  2
  •   marcopah    6 年前

    一种方法是在Python中将Keras模型转换为TensorFlow模型,然后将冻结图导出到 .pb 文件然后用C++加载。我用这段代码导出了一个冻结的 PB 来自Keras的文件。

    import tensorflow as tf
    
    from keras import backend as K
    from tensorflow.python.framework import graph_util
    
    K.set_learning_phase(0)
    model = function_that_returns_your_keras_model()
    sess = K.get_session()
    
    output_node_name = "my_output_node" # Name of your output node
    
    with sess as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        graph_def = sess.graph.as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(
                                                                     sess,
                                                                     sess.graph.as_graph_def(),
                                                                     output_node_name.split(","))
        tf.train.write_graph(output_graph_def,
                             logdir="my_dir",
                             name="my_model.pb",
                             as_text=False)
    

    然后,您可以按照任何有关如何加载的教程进行操作 PB C++中的文件。例如: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f

    Keras在TensorFlow图中注入learning\u phase变量,还可能注入其他Keras专用变量-如果我没记错的话,您应该确保从图中删除这些变量。