代码之家  ›  专栏  ›  技术社区  ›  JarsOfJam-Scheduler

在{TF 2.0.0-beta1;Keras}模型上使用tflite_convert时的“未知(自定义)丢失函数”

  •  0
  • JarsOfJam-Scheduler  · 技术社区  · 5 年前

    摘要

    • 我展示我的项目、工作环境和工作流程的上下文
    • 我的代码中有关的部分
    • 我试图解决问题的方法

    上下文

    我已经编写了一个Python Keras实现,它是原始超分辨率GAN的降级版本。现在我想用Google的Firebase机器学习工具包来测试它,把它放在Google服务器上。这就是为什么我必须把我的Keras程序转换成TensorFlow Lite程序的原因。

    环境和工作流(有问题)

    我正在Google Colab工作环境上培训我的程序:在那里,我已经安装了 TF 2.0.0-beta1 https://datascience.stackexchange.com/a/57408/78409 ).

    1. 我在本地编写Python Keras程序,并记住它将在TF 2上运行。所以我使用TF 2导入,例如: from tensorflow.keras.optimizers import Adam 还有 from tensorflow.keras.layers import Conv2D, BatchNormalization

    2. 我运行没有任何问题我的Google Colab笔记本:TF 2是使用。

    3. 我在驱动器中获取输出模型,然后下载它。

    4. 我尝试通过执行以下CLI将此模型转换为TFLite格式: tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5 问题出现在这里 .

    问题

    值错误:未知损耗函数:建立vgg19损耗网络

    功能 build_vgg19_loss_network

    引起这个问题的部分代码

    提供自定义丢失功能

    自定义丢失函数的实现方式如下:

    def build_vgg19_loss_network(ground_truth_image, predicted_image):
        loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
        return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))
    

    generator_model.compile(optimizer=the_optimizer, loss=build_vgg19_loss_network)

    我为解决这个问题所做的努力

    1. 当我在StackOverflow(本问题开头的link)上阅读时,TF 2被认为足以输出一个Keras模型,该模型将由我的 tflite_convert 克利。但显然不是。

    2. 在GitHub上阅读时,我试图在Keras的loss函数中手动设置自定义loss函数,方法是添加以下几行: import tensorflow.keras.losses tensorflow.keras.losses.build_vgg19_loss_network = build_vgg19_loss_network

    3. 我在GitHub上读到我可以使用自定义对象 load_model 凯拉斯函数:但我只想使用 compile Keras函数。不是 装载模式 .

    我的最后一个问题

    编译 具有 . 有了这个限制,你能帮我做个命令行吗 tflite_转换

    0 回复  |  直到 5 年前
        1
  •  2
  •   Prasad    5 年前

    由于您声称TFLite转换由于自定义丢失函数而失败,因此可以保存模型文件而不保留优化器的详细信息。这样做,设置 include_optimizer 参数设置为False,如下所示:

    model.save('model.h5', include_optimizer=False)
    

    编辑: 然后可以这样转换h5文件:

    import tensorflow as tf
    
    model = tf.keras.models.load_model('model.h5')   # srgan.h5 for you
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    

    克服TFLite转换中不支持的运算符的通常做法是 documented here .