【发布时间】:2017-03-08 20:58:46
【问题描述】:
我想从 SparseTensor 的一行中获取所有非零值,因此“m”是我拥有的稀疏张量对象,而 row 是我想从中获取所有非零值和索引的行。所以我想返回一个数组,它是 [(index, values)]。我希望我能在这个问题上得到一些帮助。
def nonzeros( m, row):
res = []
indices = m.indices
values = m.values
userindices = tf.where(tf.equal(indices[:,0], tf.constant(0, dtype=tf.int64)))
res = tf.map_fn(lambda index:(indices[index][1], values[index]), userindices)
return res
终端中的错误消息
TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'.
编辑: 非零输入 cm 是一个带有值的 coo_matrix
m = tf.SparseTensor(indices=np.array([row,col]).T,
values=cm.data,
dense_shape=[10, 10])
nonzeros(m, 1)
如果数据是
[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 2.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
结果应该是
[index, value]
[4,1]
[9,2]
【问题讨论】:
-
你能报告一个输入输出的例子吗?这样我们才能更好地了解如何得到你想要的。