【问题标题】:Fill a specific index in tensor with a value用值填充张量中的特定索引
【发布时间】:2018-07-28 06:19:57
【问题描述】:

我是 tensorflow 的初学者。 我创建了这个张量

z = tf.zeros([20,2], tf.float32)

我想将索引z[2,1]z[2,2] 的值更改为1.0 而不是零。 我该怎么做?

【问题讨论】:

    标签: python arrays tensorflow machine-learning variable-assignment


    【解决方案1】:

    针对这个问题的 Tensorflow 2.x 解决方案如下所示:

    import tensorflow as tf
    
    z = tf.zeros([20,2], dtype=tf.float32)
    
    index1 = [2, 0]
    index2 = [2, 1]
    
    result = tf.tensor_scatter_nd_update(z, [index1, index2], [1.0, 1.0])
    tf.print(result)
    
    [[0 0]
     [0 0]
     [1 1]
     ...
     [0 0]
     [0 0]
     [0 0]]
    

    【讨论】:

      【解决方案2】:

      一个更好的方法是使用tf.sparse_to_dense

      tf.sparse_to_dense(sparse_indices=[[0, 0], [1, 2]],
                         output_shape=[3, 4],
                         default_value=0,
                         sparse_values=1,
                         )
      

      输出:

      [[1, 0, 0, 0]
       [0, 0, 1, 0]
       [0, 0, 0, 0]]
      

      但是,tf.sparse_to_dense 最近已被弃用。因此,使用tf.SparseTensor 然后使用tf.sparse.to_dense 得到与上面相同的结果

      【讨论】:

        【解决方案3】:

        一个简单的方法:

        import numpy as np
        import tensorflow as tf
        
        init = np.zeros((20,2), np.float32)
        init[2,1] = 1.0
        z = tf.variable(init)
        

        或使用tf.scatter_update(ref, indices, updates) https://www.tensorflow.org/api_docs/python/tf/scatter_update

        【讨论】:

          【解决方案4】:

          完全问的问题是不可能的,原因有两个:

          • z 是一个常数张量,不能改变。
          • 没有z[2,2],只有z[2,0]z[2,1]

          但假设您想将z 更改为变量并修复索引,则可以这样做:

          z = tf.Variable(tf.zeros([20,2], tf.float32))  # a variable, not a const
          assign21 = tf.assign(z[2, 0], 1.0)             # an op to update z
          assign22 = tf.assign(z[2, 1], 1.0)             # an op to update z
          
          with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run(z))                           # prints all zeros
            sess.run([assign21, assign22])
            print(sess.run(z))                           # prints 1.0 in the 3d row
          

          【讨论】:

          • 这适用于 Tensorflow 1。在 TF-2 中运行此代码时:AttributeError: module 'tensorflow' has no attribute 'assign'
          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2019-12-24
          • 1970-01-01
          • 2018-03-24
          • 1970-01-01
          • 2016-02-28
          • 1970-01-01
          相关资源
          最近更新 更多