Question title: How to parse String output of a tensorflow model
Posted: 2017-11-15 23:51:12
Question:

I created a model using the code here: https://gist.github.com/gaganmalhotra/1424bd3d0617e784976b29d5846b16b1

Getting the predicted probabilities in Java can be done with the following code:

import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class IrisTFLoad {
  public static void main(String[] args) {
    // Load the exported SavedModel with the "serve" tag and grab its session.
    Session session = SavedModelBundle.load("/Users/gagandeep.malhotra/Documents/SampleTF_projects/tf_iris_model/1510707746/", "serve").session();

    // Two iris samples, four features each: shape [2, 4].
    Tensor x =
        Tensor.create(
            new long[] {2, 4},
            FloatBuffer.wrap(
                new float[] {
                  6.4f, 3.2f, 4.5f, 1.5f,
                  5.8f, 3.1f, 5.0f, 1.7f
                }));

    final String xName = "Placeholder:0";
    final String scoresName = "dnn/head/predictions/probabilities:0";

    List<Tensor<?>> outputs = session.runner()
        .feed(xName, x)
        .fetch(scoresName)
        .run();

    // Outer dimension is batch size; inner dimension is number of classes
    float[][] scores = new float[2][3];

    outputs.get(0).copyTo(scores);
    System.out.println(Arrays.deepToString(scores));
  }
}
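Once the probabilities are in a `float[][]`, the predicted class index can be recovered without touching the string output at all: it is the position of the largest probability in each row. This is a plain-Java sketch, not TensorFlow API; `ProbabilityArgmax` and `argmax` are hypothetical names, and it assumes the columns of `probabilities` are ordered by class id, which is what the signature's `class_ids` output (DT_INT64, shape (-1, 1)) would normally report. Fetching `dnn/head/predictions/ExpandDims:0` into a `long[2][1]` should be another way to get those ids directly.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the TensorFlow Java API: picks the index
// of the highest probability in each row of a [batchSize][numClasses] array.
public class ProbabilityArgmax {

  static long[] argmax(float[][] scores) {
    long[] ids = new long[scores.length];
    for (int i = 0; i < scores.length; ++i) {
      int best = 0;
      for (int j = 1; j < scores[i].length; ++j) {
        // Strict ">" means ties resolve to the lowest index.
        if (scores[i][j] > scores[i][best]) {
          best = j;
        }
      }
      ids[i] = best;
    }
    return ids;
  }

  public static void main(String[] args) {
    // Made-up values standing in for the output of copyTo(scores).
    float[][] scores = {
      {0.05f, 0.80f, 0.15f},
      {0.10f, 0.20f, 0.70f}
    };
    System.out.println(Arrays.toString(argmax(scores))); // prints [1, 2]
  }
}
```

Returning `long` mirrors the DT_INT64 type of `class_ids`, so the result can be compared directly against that output if both are fetched.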

However, if we want to copy out the predicted classes (of String type) with the code below:

final String xName = "Placeholder:0";
final String className = "dnn/head/predictions/str_classes:0";

List<Tensor<?>> outputs = session.runner()
    .feed(xName, x)
    .fetch(className)
    .run();

// Outer dimension is batch size; inner dimension is 1 (one class per example)
String[][] classes = new String[2][1];

outputs.get(0).copyTo(classes);
System.out.println(Arrays.deepToString(classes));

I end up with an error like this:

Exception in thread "main" java.lang.IllegalArgumentException: cannot copy Tensor with 2 dimensions into an object with 1
    at org.tensorflow.Tensor.throwExceptionIfTypeIsIncompatible(Tensor.java:739)
    at org.tensorflow.Tensor.copyTo(Tensor.java:450)
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:71)

But the dimensions are the same as the output tensor's: [STRING tensor with shape [2, 1]]

PS: The signature definition is as follows:

The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 4)
        name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: dnn/head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: dnn/head/predictions/str_classes:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 3)
        name: dnn/head/logits:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 3)
        name: dnn/head/predictions/probabilities:0
Method name is: tensorflow/serving/predict

Things I've tried:

Tensor<?> tensor = outputs.get(0);
byte[][][] result = tensor.copyTo(new byte[2][1][]);

But I get the following error:

Exception in thread "main" java.lang.IllegalStateException: invalid DataType(7)
    at org.tensorflow.Tensor.readNDArray(Native Method)
    at org.tensorflow.Tensor.copyTo(Tensor.java:451)
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:74)

Tags: java python tensorflow tensorflow-serving


    Answer 1:

    TensorFlow tensors of type DT_STRING contain arbitrary byte sequences as elements, not Java Strings (sequences of characters).

    So what you want is something like this:

    byte[][][] classes = new byte[2][1][];
    outputs.get(0).copyTo(classes);
    

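    If you want Java String objects, you need to know which encoding your model uses for the classes it produces; then you can do something like this (assuming UTF-8 encoding):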

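    import static java.nio.charset.StandardCharsets.UTF_8;

    String[][] classesStrings = new String[2][1];
    for (int i = 0; i < classes.length; ++i) {
      for (int j = 0; j < classes[i].length; ++j) {
        // Decode each element's raw bytes using the model's (assumed UTF-8) encoding.
        classesStrings[i][j] = new String(classes[i][j], UTF_8);
      }
    }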
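    The same loop can be wrapped into a small helper that sizes its result from the input rather than hard-coding [2][1], so it keeps working when the batch size changes. This is a plain-Java sketch; `StringTensorDecoder` and `decodeUtf8` are hypothetical names, and UTF-8 is an assumption about the model's output encoding, as above.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical helper: decodes the byte[][][] that copyTo fills for a
// rank-2 DT_STRING tensor into Java Strings, assuming UTF-8 encoding.
public class StringTensorDecoder {

  static String[][] decodeUtf8(byte[][][] raw) {
    // Size the result from the input instead of hard-coding [2][1].
    String[][] out = new String[raw.length][];
    for (int i = 0; i < raw.length; ++i) {
      out[i] = new String[raw[i].length];
      for (int j = 0; j < raw[i].length; ++j) {
        out[i][j] = new String(raw[i][j], StandardCharsets.UTF_8);
      }
    }
    return out;
  }

  public static void main(String[] args) {
    // Made-up bytes standing in for what copyTo(new byte[2][1][]) would write.
    byte[][][] raw = {
      {"1".getBytes(StandardCharsets.UTF_8)},
      {"2".getBytes(StandardCharsets.UTF_8)}
    };
    System.out.println(Arrays.deepToString(decodeUtf8(raw))); // prints [[1], [2]]
  }
}
```

    If the model used a different encoding, only the `Charset` argument would need to change.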
    

    Hope that helps. You might also find the unittest instructive.
