【发布时间】:2019-04-10 22:18:40
【问题描述】:
我正在尝试在不使用 Python 的情况下用 Java 制作 TensorFlow 模型。 我设法为 Java 编写了很多 Python 代码,但我缺少一些需要完成的元素。 我正在阻止优化器。 Python中的原始代码是一个非常简单的模型。
import tensorflow as tf
# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
我开始了向 Java 的转换,距离结束不远,但我卡在了优化器上。
try (Graph g = new Graph()) {
//# Batch of input and target output (1x1 matrices)
//x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
Output<OperationBuilder> x = g.opBuilder("Placeholder", "input")
.setAttr("dtype", DataType.FLOAT)
.build().output(0);
//y = tf.placeholder(tf.float32, shape=[None, 1, 1], name=target')
Output<OperationBuilder> y = g.opBuilder("Placeholder", "target")
.setAttr("dtype", DataType.FLOAT)
.build().output(0);
//# Trivial linear model
//y_ = tf.identity(tf.layers.dense(x, 1), name='output')
Tensor t = Tensor.create(new int[] {0});
Output reductionIndices = g.opBuilder("Const", "layer")
.setAttr("dtype", t.dataType()).setAttr("value", t)
.build().output(0);
Output dense = g.opBuilder("layersdense", "dense")
.setAttr("T", DataType.FLOAT)
.setAttr("Tidx", DataType.INT32)
.addInput(input).addInput(reductionIndices)
.build().output(0);
Tensor<?> t2 = Tensor.create(dense);
Output<OperationBuilder> y_ = g.opBuilder("Identity", "output")
.setAttr("value", t2)
.build().output(0);
//# Optimize loss
//loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
Output<OperationBuilder> sub=g.opBuilder("Sub","sub")
.addInput(y_).addInput(y)
.build().output(0);
Output<OperationBuilder> sq = g.opBuilder("Square", "Square")
.addInput(sub)
.build().output(0);
//optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
Code java ???
//train_op = optimizer.minimize(loss, name='train')
Code java ???
}
【问题讨论】:
-
相关:stackoverflow.com/questions/46030577/… 。上次我检查时,Java api 只允许加载经过训练的模型并进行预测,这样您就可以在 Android-Java-App 中轻松使用用 Python 训练的神经网络。您缺少的“一些元素”似乎是整个优化框架。由于 Tensor-Flow 严重依赖于用 C/C++ 编写的高度优化的低级代码,因此将整个东西移植到 Java 似乎并不容易,甚至是不可取的。
-
Andrey 是正确的,因为优化器库在 Java 中不可用。但是,这些库本质上是添加到 TensorFlow 图,因此您可以在 Java 中复制它们(无需移植任何 C/C++ 代码)。但是,一般来说,如果可能的话,您最好在 Python 中创建模型,然后从 Java 中执行训练循环。例如参见:github.com/tensorflow/models/tree/master/samples/languages/java/…
-
感谢您的回答。我认为所有的计算函数都在 tensorflow_jni.dll 中并用 JNI 调用。可以叫“Placeholder、Const、Sub、Square……”,我以为我们叫PYTHON的所有函数也是JAVA。
-
许多 Python 函数是在图中封装了多个操作的复合函数。例如,
tf.layers.dense函数添加了矩阵乘法、偏置加法和激活函数 - 因此向图中添加了多个节点。这些高级便利尚未在其他语言中实现,因此您必须单独添加每个操作。长话短说,截至 2018 年 1 月,我的建议是从 Python 程序创建和导出模型,即使您使用 Java 驱动训练过程。 -
好的,谢谢。我将安装 PYTHON 并使用它来生成图表。
标签: java tensorflow