【问题标题】:Tensorflow Mobile - Inference InterfaceTensorflow Mobile - 推理接口
【发布时间】:2018-05-11 14:25:53
【问题描述】:

我对 TensorflowInferenceInterface 有一些问题。加载张量流模型后,我需要输出节点的形状尺寸。在 Tensorboard 仪表板中,正确表示以下形状:[?,?,?,2048] 因为动态图像输入。我已经为推理过程执行了这些指令:


TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_NAME);
Operation operation = inferenceInterface.graph().operation(OUTPUT_NAME);

inferenceInterface.feed(INPUT_NAME, floatValues, 1, bitmapImage.getWidth(), bitmapImage.getHeight(), 3);

String[] outputNames =  new String[] {"activation_43/Relu"};
    inferenceInterface.run(outputNames, true);

float[] outputs = new float[690000];

inferenceInterface.fetch(OUTPUT_NAME, outputs);

我的问题是:如何获得输出张量维度?我需要形状向量中每个字段的 int 值而不是问号。

【问题讨论】:

    标签: java android tensorflow


    【解决方案1】:

    通常在进行推理时,输出形状都是静态的。似乎 TensorflowInferenceInterface 的设计考虑到了这种情况,并且没有办法获得输出形状。

    一种选择是复制粘贴源代码(它相当小,只使用其他公共接口)并公开getTensor()。然后,你可以这样做:

    Tensor<?> t = inferenceInterface.getTensor(OUTPUT_NAME);
    long[] dimensions = t.shape();
    

    或者,您可以提交描述您的用例的 github 问题。人们可能会觉得这是合理的并接受 PR 以更改实际的 getTensor() 方法可见性或提出替代方案。

    【讨论】:

    • 这是 TensorflowInferenceInterface 内部的私有方法...我该怎么办?
    • 如何覆盖该方法以使其公开?我已经复制粘贴了 TensorflowInferenceInterface 的源代码,并且我已将可见性从私有修改为公共,但我得到了如下编译错误:“getTensor() 在 org.tensorflow.contrib.android.TensorflowInferenceInterface 中具有私有访问权限”。该库链接在 build.gradle 文件中。是那个问题吗?你能帮助我吗?谢谢@iga
    • @M.Armao,您似乎没有得到您复制粘贴的代码。仅供参考,如果有多个具有相同名称的类,Java 类加载器在技术上可以选择任何一个。通常会选择最后一个加载的。您可以重命名它并在代码中使用新的类名,以确保获得正确的类。
    猜你喜欢
    • 1970-01-01
    • 2017-02-21
    • 2018-10-18
    • 2018-05-27
    • 1970-01-01
    • 1970-01-01
    • 2018-11-08
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多