【发布时间】:2019-04-12 13:38:10
【问题描述】:
在 TensorFlow 1.x 中,要更新张量,我会使用 tf.scatter_update,仅更新张量的相关部分。
我们如何在 TF 2.0 中做同样的事情?
【问题讨论】:
标签: python tensorflow tensorflow2.0
在 TensorFlow 1.x 中,要更新张量,我会使用 tf.scatter_update,仅更新张量的相关部分。
我们如何在 TF 2.0 中做同样的事情?
【问题讨论】:
标签: python tensorflow tensorflow2.0
你可以使用tf.tensor_scatter_nd_update():
import tensorflow as tf
import numpy as np
tensor = tf.convert_to_tensor(np.ones((2, 2)), dtype=tf.float32)
indices = tf.constant([[0, 0]])
updates = tf.constant([0.0])
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
# array([[0., 1.],
# [1., 1.]], dtype=float32)
【讨论】: