【问题标题】:How to perform PyTorch style tensor slice update in TensorFlow?如何在 TensorFlow 中执行 PyTorch 风格的张量切片更新?
【发布时间】:2023-04-07 11:22:02
【问题描述】:

在 Pytorch 中,您可以像这样轻松更新张量:

 for i in range(x_len):
     tensor_abc[:, i, i] = 0

我们如何在 tensorflow 中更新这样的张量?

我尝试了tf.assigntf.scatter_update,但都不起作用。

【问题讨论】:

    标签: tensorflow slice tensor


    【解决方案1】:

    此答案仅适用于变量。

    import tensorflow as tf
    
    sess = tf.InteractiveSession()
    v = tf.zeros((5,5,5))
    var = tf.Variable(initial_value=v)
    
    
    init = tf.variables_initializer([var])
    sess.run(init)
    
    
    var = var[ 1 : 2 ,
               1 : 2 ,
               1 : 2 ].assign(tf.ones((1,1,1)))
    
    print(sess.run(var))
    

    这会产生

    [[[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[0. 0. 0. 0. 0.]
      [0. 1. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]]
    

    还有这个

    var = var[ 1 : 2 ,
               0 : 1 ,
               0 : 1 ].assign(tf.ones((1,1,1)))
    

    生产

      [[[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[1. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
      ....
      ....]]
    

    另一个例子是

    var = var[ 1 : 2 ,
                 : 2 ,
                 : 2 ].assign(tf.ones((1,2,2)))
    
    [[[0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
     [[1. 1. 0. 0. 0.]
      [1. 1. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]
      [0. 0. 0. 0. 0.]]
    
          ....
          ....]]
    

    您应该探索 tf.scatter_nd 的张量。

    【讨论】:

      【解决方案2】:

      tf.Variable 是唯一可以更新的张量。对于变量,您可以使用 gatherscatter_update 之类的代码进行切片。

      请注意,其他张量不适合赋值。如果这是你想要做的,我想知道为什么它是必要的。但是,仍然可以使用您想要的值(而不是就地分配)创建新的张量,代码有点复杂。例如,以下内容不起作用:

      index = ... tensor = tf.constant([0,1,2,3,4]) 
      tensor[i] = 0  
      ## Doesn't work (TypeError: `Tensor` object does not support item assignment)
      

      但其中任何一个都可以做同样的事情:

      tensor = tf.constant([0,1,2,3,4]) 
      tensor = tf.concat([tensor[:i], tf.zeros_like(tensor[i:i+1]), tensor[i+1:]], 0)  
      ## This works, creates a new tensor
      

      张量 = tf.constant([0,1,2,3,4]) 张量 = tf.concat([tensor[:i], tf.fill([1], 0), tensor[i+1:]], 0) ## 这行得通,创建一个新的张量

      【讨论】:

      • 创建新张量可以,但是多维张量可以吗?例如将下面张量的所有对角线更新为 0: tensor = tf.Variable([ [[1,2,3],[4,5,6],[7,8,9] ], [[3,2, 1],[6,5,4],[9,8,7] ], [[2,2,2],[2,2,2],[2,2,2] ] ])
      • 如果你使用变量,你可以使用tf.scatter_update而不是创建新的张量。对于您的具体示例,“所有对角线”是什么意思? (你指的是哪个轴?)你能告诉我一个所需结果矩阵的例子吗?这将允许我提供更具体的代码示例
      猜你喜欢
      • 2020-08-01
      • 1970-01-01
      • 2016-05-22
      • 2019-10-21
      • 2018-02-24
      • 1970-01-01
      • 2018-05-07
      • 1970-01-01
      • 2020-02-16
      相关资源
      最近更新 更多