【问题标题】:Tensorflow Extract Indices Not Equal to ZeroTensorFlow 提取索引不等于零
【发布时间】:2018-02-19 00:43:16
【问题描述】:

我想为每一行返回一个非零索引的密集张量。例如,给定张量:

[0,1,1]
[1,0,0]
[0,0,1]
[0,1,0]

应该返回

[1,2]
[0]
[2]
[1]

我可以使用 tf.where() 获取索引,但我不知道如何根据第一个索引组合结果。例如:

graph = tf.Graph()
with graph.as_default():
    data = tf.constant([[0,1,1],[1,0,0],[0,0,1],[0,1,0]])
    indices = tf.where(tf.not_equal(data,0))
sess = tf.InteractiveSession(graph=graph)
sess.run(tf.local_variables_initializer())
print(sess.run([indices]))

以上代码返回:

[array([[0, 1],
       [0, 2],
       [1, 0],
       [2, 2],
       [3, 1]])]

但是,我想根据这些索引的第一列组合结果。有人可以建议一种方法吗?

更新

试图使其适用于更多维度并遇到错误。如果我在矩阵上运行下面的代码

sess = tf.InteractiveSession()
a = tf.constant([[0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
row_counts = tf.reduce_sum(a, axis=1)
max_padding = tf.reduce_max(row_counts)
extra_padding = max_padding - row_counts
extra_padding_col = tf.expand_dims(extra_padding, 1)
range_row = tf.expand_dims(tf.range(max_padding), 0)
padding_array = tf.cast(tf.tile(range_row, [9, 1])<extra_padding_col, tf.int32)
b = tf.concat([a, padding_array], axis=1)
result = tf.map_fn(lambda x: tf.cast(tf.where(tf.not_equal(x, 0)), tf.int32), b)
result = tf.where(result<=max_padding, result, -1*tf.ones_like(result)) # replace with -1's
result = tf.reshape(result, (int(result.get_shape()[0]), max_padding))
result.eval()

然后我会得到太多的 -1,所以解决方案似乎并不完全存在:

[[ 1,  2],
       [ 2, -1],
       [-1, -1],
       [-1, -1],
       [-1, -1],
       [-1, -1],
       [-1, -1],
       [-1, -1],
       [ 0, -1]]

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    请注意,在您的示例中,输出不是矩阵,而是锯齿状数组。交错数组在TensorFlow中的支持有限(通过TensorArray),所以处理矩形数组更方便。您可以用 -1 填充每一行以使输出为矩形

    假设您的输出已经是矩形,没有填充,您可以使用map_fn,如下所示

    tf.reset_default_graph()
    sess = tf.InteractiveSession()
    a = tf.constant([[0,1,1],[1,1,0],[1,0,1],[1,1,0]])
    # cast needed because map_fn likes to keep same dtype, but tf.where returns int64
    result = tf.map_fn(lambda x: tf.cast(tf.where(tf.not_equal(x, 0)), tf.int32), a)
    # remove extra level of nesting
    sess.run(tf.reshape(result, (4, 2)))
    

    输出是

    array([[1, 2],
           [0, 1],
           [0, 2],
           [0, 1]], dtype=int32)
    

    当需要填充时,你可以这样做

    sess = tf.InteractiveSession()
    a = tf.constant([[0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
    row_counts = tf.reduce_sum(a, axis=1)
    max_padding = tf.reduce_max(row_counts)
    max_index = int(a.get_shape()[1])
    extra_padding = max_padding - row_counts
    extra_padding_col = tf.expand_dims(extra_padding, 1)
    range_row = tf.expand_dims(tf.range(max_padding), 0)
    num_rows = tf.squeeze(tf.shape(a)[0])
    padding_array = tf.cast(tf.tile(range_row, [num_rows, 1])<extra_padding_col, tf.int32)
    b = tf.concat([a, padding_array], axis=1)
    result = tf.map_fn(lambda x: tf.cast(tf.where(tf.not_equal(x, 0)), tf.int32), b)
    result = tf.where(result<max_index, result, -1*tf.ones_like(result)) # replace with -1's
    result = tf.reshape(result, (int(result.get_shape()[0]), max_padding))
    result.eval()
    

    这应该产生

    array([[ 1,  2],
           [ 2, -1],
           [ 4, -1],
           [ 5,  6],
           [ 6, -1],
           [ 7,  9],
           [ 8, -1],
           [ 9, -1],
           [ 0,  9]], dtype=int32)
    

    【讨论】:

    • 这看起来真的很接近!我试图在更大的矩阵上运行它(更新了我上面的问题),但我得到了不正确的结果。我认为它是从 tf.where(result
    • update -- 我很确定结果
    • 你怎么看:result&lt;=tf.squeeze(tf.shape(a)[0])作为更正?
    • 对不起result&lt;=tf.squeeze(tf.shape(a)[0])-1 我相信是正确的
    • 好的,我有一个错误,我的“-1”替换应该是替换大于最大有效索引的值,用固定版本替换代码
    猜你喜欢
    • 2016-07-02
    • 2011-07-16
    • 1970-01-01
    • 2018-05-24
    • 2019-05-05
    • 2016-06-04
    • 2012-12-29
    • 2019-09-06
    • 1970-01-01
    相关资源
    最近更新 更多