【问题标题】:tensorflow loss minimization type error张量流损失最小化类型错误
【发布时间】:2015-11-13 21:42:58
【问题描述】:

我在 TensorFlow 中实现了一个计算均方误差的损失函数。所有用于计算目标的张量都是 float64 类型,因此损失函数本身是 dtype float64。特别是,

print cost
==> Tensor("add_5:0", shape=TensorShape([]), dtype=float64)

但是,当我尝试最小化时,我得到一个关于张量类型的值错误:

GradientDescentOptimizer(learning_rate=0.1).minimize(cost)
==> ValueError: Invalid type <dtype: 'float64'> for add_5:0, expected: [tf.float32].

当导致计算的所有变量都是 float64 类型时,我不明白为什么张量的预期 dtype 是单精度浮点数。我已经确认,当我将所有变量强制为 float32 时,计算会正确执行。

有没有人知道为什么会发生这种情况?我的电脑是64位机器。

这是一个重现行为的示例

import tensorflow as tf
import numpy as np

# Make 100 phony data points in NumPy.
x_data = np.random.rand(2, 100) # Random input
y_data = np.dot([0.100, 0.200], x_data) + 0.300

# Construct a linear model.
b = tf.Variable(tf.zeros([1], dtype=np.float64))
W = tf.Variable(tf.random_uniform([1, 2], minval=-1.0, maxval=1.0, dtype=np.float64))
y = tf.matmul(W, x_data) + b

# Minimize the squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# For initializing the variables.
init = tf.initialize_all_variables()

# Launch the graph
sess = tf.Session()
sess.run(init)

# Fit the plane.
for step in xrange(0, 201):
    sess.run(train)
    if step % 20 == 0:
        print step, sess.run(W), sess.run(b)

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    目前 tf.train.GradientDescentOptimizer 类仅在 supports 上训练 32 位浮点变量和损失值。

    但是,内核似乎是为双精度值实现的,因此应该可以在您的场景中进行训练。

    一种快速的解决方法是定义一个同时支持tf.float64 值的子类:

    class DoubleGDOptimizer(tf.train.GradientDescentOptimizer):
      def _valid_dtypes(self):
        return set([tf.float32, tf.float64])
    

    ...然后使用 DoubleGDOptimizer 代替 tf.train.GradientDescentOptimizer

    编辑:您需要将学习率传递为tf.constant(learning_rate, tf.float64) 才能完成这项工作。

    (注意这不是一个受支持的接口,将来可能会发生变化,但团队意识到优化双精度浮点数的愿望,并打算提供一个内置的-在解决方案中。)

    【讨论】:

    • 现在似乎不起作用(tf v0.6)。 TypeError: Input 'alpha' of 'ApplyGradientDescent' Op has type float32 that does not match type float64 of argument 'var'.
    • 感谢您指出这一点。我修改了答案。
    • 这似乎不适用于像 ADAM 这样的其他(更复杂的)优化器......
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-12-07
    • 2018-12-16
    • 1970-01-01
    • 2018-02-03
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多