【问题标题】:tape = tape if tape is not None else backprop.GradientTape()磁带 = 磁带 如果磁带不是 无 其他 backprop.GradientTape()
【发布时间】:2021-07-28 02:49:37
【问题描述】:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
def line(x):
  return 2*x+4
X = np.arange(0,20)
y = [k for k in line(X)]
a = tf.Variable(1.0)
b = tf.Variable(0.2)

y_in = a*X + b
loss = tf.reduce_mean(tf.square(y_in - y))
#this is my old code
#optimizer = tf.train.GradientDescentOptimizer(0.2)
#train = optimizer.minimize(loss)

#new Code
optimizer = tf.optimizers.SGD (0.2)
train = optimizer.minimize(loss,var_list=[a,b])

///错误

ValueError Traceback(最近一次调用最后一次) 在 () ----> 1 列车 = optimizer.minimize(loss,var_list=[a,b])

1 帧 /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in _compute_gradients(self, loss, var_list, grad_loss, tape) 530 # TODO(josh11b): 测试我们是否以合理的方式处理权重衰减。 531 如果不可调用(丢失)并且磁带为无: --> 532 raise ValueError("tape is required when a Tensor loss is passed.") 第533章 534

ValueError: tape 在传递 Tensor 损失时是必需的。

【问题讨论】:

  • 仅供参考,如果没有任何有效值可以是假的,那么将其写为tape or backprop.GradientTape() 可能会更容易一些。

标签: python numpy tensorflow regression linear-regression


【解决方案1】:

你还有更多的路要走!您需要计算 grad ,然后使用优化器更改变量。我修改了你的代码。另外,你的损失函数也不好用。

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

def line(x):
  return 2*x+4
X = np.arange(0,20)
y = tf.constant(np.array([k for k in line(X)], dtype=np.float32))
a = tf.Variable(1.0, trainable=True)
b = tf.Variable(0.2, trainable=True)

def objective_fun(X):
    y_in = a * X + b
    return y_in


def loss_fun(y_true, y_pred):
    # loss = tf.reduce_mean(tf.square(y_true - y_pred))
    loss = tf.reduce_mean(tf.abs(y_pred - y_true))
    return loss

optimizer = tf.optimizers.SGD (0.01)

MAX_ITER = 1000
for it in range(MAX_ITER):
    with tf.GradientTape() as tape:
        y_pred = objective_fun(X)
        loss = loss_fun(y_pred, y)
    grad = tape.gradient(loss, [a, b])
    optimizer.apply_gradients(zip(grad, [a, b]))
    print(loss.numpy())

这是优化的结果:

a.numpy(), b.numpy()

(1.9880208, 3.8429925)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2018-08-06
    • 2011-05-12
    • 2015-04-24
    • 2018-02-21
    • 2022-01-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多