代码之家  ›  专栏  ›  技术社区  ›  gagan malhotra

使用估计器在java中加载/服务tensorflow模型时的问题

  •  2
  • gagan malhotra  · 技术社区  · 7 年前

    我使用了人口普查数据,并使用tensorflow中的估计器api创建了一个广泛而深入的模型。在Java中加载模型时,似乎出现了一个错误,无法加载模型。异常看起来像

    Exception in thread "main" org.tensorflow.TensorFlowException: Op type not 
    registered 'SparseFeatureCross' in binary running on gmalhotra-mba-2.local. 
    Make sure the Op and Kernel are registered in the binary running in this 
    process.
    at org.tensorflow.SavedModelBundle.load(Native Method)
    at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:39)
    at deeplearning.DeepLearningTest.main(DeepLearningTest.java:32)
    

    请查找以下用于保存模型的python代码: https://gist.github.com/gaganmalhotra/cd6a5898b9caf9005a05c8831a9b9153

    使用的Java代码如下:

        public static void main(String[] args) {
              try (SavedModelBundle b = SavedModelBundle.load("/Users/gagandeep.malhotra/Documents/SampleTF_projects/temporaryModel/1510624417/", "serve")) {
    
    
        Session sess = b.session();
    
                    //Create the input sensor 
                      float[][] mat=new float[1][1];
                      mat[0]=new float[]{0.5f};
    
                    // create tensors specific to inputs ....
    
                    Tensor<?> x = (Tensor<?>) Tensor.create(mat);
    
                    //run the model 
                    float[][] y = sess.runner()
                            .feed("input", x)
                            .fetch("output")
                            .run()
                            .get(0)
                            .copyTo(new float[1][1]);               
    
                   //print the result
                    System.out.println(y[0][0]);
    }
    

    PS:使用的Tensorflow版本:1.3

    1 回复  |  直到 7 年前
        1
  •  6
  •   ash    7 年前

    在中使用操作时 tf.contrib 模块,它们不被认为是实验性的,因此不属于 stable TensorFlow API 并且不包括在其他语言发行版中。

    然而,在TensorFlow 1.4及更高版本中,您可以使用 TensorFlow.loadLibrary() .

    为此,首先需要找到包含实现的共享库的位置 tf。contrib公司 tf.contrib.layers ,所以你可以这样做:

    python -c "import tensorflow; print(tensorflow.contrib.layers.__path__)"
    

    打印内容如下:

    ['/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/layers']
    

    然后,您可以使用以下内容找到该路径中的所有共享库:

    find /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/layers -name "*.so"
    

    这将类似于:

    /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/layers/python/ops/_sparse_feature_cross_op.so
    

    好的,现在您有了这个库,您可以使用以下方法在Java中加载它:

    public static void main(String[] args) {
        TensorFlow.loadLibrary("/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/layers/python/ops/_sparse_feature_cross_op.so");
    
        // And now load the model etc.
    }
    

    • 如果你想在不同的机器上运行,你需要打包 .so 用您的程序归档以上内容,并将调用调整为 TensorFlow。loadLibrary() 适当。

    • 确保您对Python和Java(1.4)使用相同的TensorFlow版本

    希望这有帮助。