【问题标题】:How to update a subset of 2D tensor in Tensorflow?如何在 Tensorflow 中更新二维张量的子集?
【发布时间】:2016-10-04 18:41:12
【问题描述】:

我想更新值为 0 的 2D 张量中的索引。所以 data 是一个 2D 张量,其第 2 行第 2 列索引值将被 0 替换。但是,我收到了类型错误。谁能帮帮我?

TypeError: Input 'ref' of 'ScatterUpdate' Op 需要左值输入

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])

【问题讨论】:

  • 一般来说,根据更新的复杂程度以及您是否计划将此过程作为大图的一部分,您可能需要查看tf.py_func,这是一个允许您在 tf 图中添加 numpy 操作。

标签: python neural-network tensorflow deep-learning


【解决方案1】:

散点图更新仅适用于变量。请尝试这种模式。

Tensorflow 版本 a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])

Tensorflow 版本 >= 1.0: a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])

【讨论】:

  • 非常有帮助,回答了应该问的问题
  • 为什么tensorflow 2没有这个功能?
【解决方案2】:

tf.scatter_update 只能应用于Variable 类型。代码中的dataVariable,而data2 不是,因为tf.reshape 的返回类型是Tensor

解决方案:

用于 v1.0 之后的 tensorflow

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

用于 v1.0 之前的 tensorflow

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

【讨论】:

  • 它不起作用:ValueError:两个形状中的维度 0 必须相等,但是是 1 和 2 来自将形状 1 与其他形状合并。对于具有输入形状的“concat/concat_dim”(操作:“Pack”):[2]、[1]、[2]。
  • @AntonioSesto TensorFlow v1.0 更改了一些接口以与 numpy 保持一致。 concat 是更改之一。我已经编辑了我的答案以使用 v1.0。
【解决方案3】:

这是我在 Tensorflow 2 中用于修改 2D 张量的子集(行/列)的函数:

#note if updatedValue isVector, updatedValue should be provided in 2D format
def modifyTensorRowColumn(a, isRow, index, updatedValue, isVector):
    
    if(not isRow):
        a = tf.transpose(a)
        if(isVector):
            updatedValue = tf.transpose(updatedValue)
    
    if(index == 0):
        if(isVector):
            values = [updatedValue, a[index+1:]]
        else:
            values = [[updatedValue], a[index+1:]]
    elif(index == a.shape[0]-1):
        if(isVector):
            values = [a[:index], updatedValue]
        else:
            values = [a[:index], [updatedValue]]
    else:
        if(isVector):
            values = [a[:index], updatedValue, a[index+1:]]
        else:
            values = [a[:index], [updatedValue], a[index+1:]]
            
    a = tf.concat(axis=0, values=values)
            
    if(not isRow):
        a = tf.transpose(a)
        
    return a

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-05-13
    • 1970-01-01
    • 2018-11-21
    相关资源
    最近更新 更多