【问题标题】:Get nonzeros row of a SparseTensor获取 SparseTensor 的非零行
【发布时间】:2017-03-08 20:58:46
【问题描述】:

我想从 SparseTensor 的一行中获取所有非零值,因此“m”是我拥有的稀疏张量对象,而 row 是我想从中获取所有非零值和索引的行。所以我想返回一个数组,它是 [(index, values)]。我希望我能在这个问题上得到一些帮助。

def nonzeros( m, row):
    res = []
    indices = m.indices
    values = m.values
    userindices = tf.where(tf.equal(indices[:,0], tf.constant(0, dtype=tf.int64)))
    res = tf.map_fn(lambda index:(indices[index][1], values[index]), userindices)
    return res

终端中的错误消息

TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'.

编辑: 非零输入 cm 是一个带有值的 coo_matrix

m = tf.SparseTensor(indices=np.array([row,col]).T,
                        values=cm.data,
                        dense_shape=[10, 10])
nonzeros(m, 1)

如果数据是

[[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  2.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]]

结果应该是

[index, value]
[4,1]
[9,2]

【问题讨论】:

  • 你能报告一个输入输出的例子吗?这样我们才能更好地了解如何得到你想要的。

标签: tensorflow sparse-matrix


【解决方案1】:

问题是 lambda 中的 index 是一个张量,你不能直接使用它来索引例如indices。您可以改用tf.gather。另外,您发布的代码中没有使用row 参数。

试试这个:

import tensorflow as tf
import numpy as np

def nonzeros(m, row):
    indices = m.indices
    values = m.values
    userindices = tf.where(tf.equal(indices[:, 0], row))
    found_idx = tf.gather(indices, userindices)[:, 0, 1]
    found_vals = tf.gather(values, userindices)[:, 0:1]
    res = tf.concat(1, [tf.expand_dims(tf.cast(found_idx, tf.float64), -1), found_vals])
    return res

data = np.array([[0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.],
                [0., 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  2.]])

m = tf.SparseTensor(indices=np.array([[0, 1], [0, 9], [1, 4], [1, 9]]),
                    values=np.array([1.0, 1.0, 1.0, 2.0]),
                    shape=[2, 10])

with tf.Session() as sess:
    result = nonzeros(m, 1)
    print(sess.run(result))

哪个打印:

[[ 4.  1.]
 [ 9.  2.]]

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-12-27
    • 2021-01-26
    • 1970-01-01
    • 2021-11-25
    相关资源
    最近更新 更多