【问题标题】:TensorFlow 2.0: How to update tensors?TensorFlow 2.0:如何更新张量?
【发布时间】:2019-04-12 13:38:10
【问题描述】:

在 TensorFlow 1.x 中,要更新张量,我会使用 tf.scatter_update,仅更新张量的相关部分。

我们如何在 TF 2.0 中做同样的事情?

【问题讨论】:

    标签: python tensorflow tensorflow2.0


    【解决方案1】:

    你可以使用tf.tensor_scatter_nd_update():

    import tensorflow as tf
    import numpy as np 
    
    tensor = tf.convert_to_tensor(np.ones((2, 2)), dtype=tf.float32)
    indices = tf.constant([[0, 0]])
    updates = tf.constant([0.0])
    
    tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
    # array([[0., 1.],
    #        [1., 1.]], dtype=float32)
    

    【讨论】:

    • 实际上这并没有做我需要的,它创建了一个新对象。有没有办法使用对同一个对象的引用?如果我想更新一些内存变量,例如重播缓冲区,tensorflow 会返回错误:“函数构建代码之外的操作正在传递一个“Graph”张量。”
    • 我找到了自己问题的答案:tensor.assign(tf.tensor_scatter_nd_update(....works.
    • @gpernelle 在您的问题中,您实际上想要更新变量而不是张量对象(这就是 tf.scatter_update 在 1.x 中所做的)。因此,弗拉德的答案在更新张量方面是正确的。你的回答对于更新变量是正确的,虽然我不知道为什么这个功能在 2.x 中消失了。
    猜你喜欢
    • 2018-05-13
    • 2021-02-07
    • 2023-03-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多