【发布时间】:2017-07-06 15:15:20
【问题描述】:
我想检查批次的偶数和奇数元素,并在需要时交换它们。我设法得到了两个我想要交织的张量:
def tf_oplu(x, name=None):
even = x[:,::2] #slicing into odd and even parts on the batch
odd = x[:,1::2]
even_flatten = tf.reshape(even, [-1]) # flatten tensors
#in row-major order to apply function across them
odd_flatten = tf.reshape(odd, [-1])
compare = tf.to_float(even_flatten<odd_flatten)
compare_not = tf.to_float(even_flatten>=odd_flatten)
#def oplu(x,y): # trivial function
# if x<y : # (x<y)==1
# return y, x
# else:
# return x, y # (x<y)==0
even_flatten_new = odd_flatten * compare + even_flatten * compare_not
odd_flatten_new = odd_flatten * compare_not + even_flatten * compare
# convolute back
even_new = tf.reshape(even_flatten_new,[100,128])
odd_new = tf.reshape(odd_flatten_new,[100,128])
现在我想返回填充了偶数和奇数位的 $[100,256]$ 张量。在 numpy 中,我当然会这样做:
y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype)
y[:,0::2] = even_new
y[:,1::2] = odd_new
return y
但是对于 tensoflow 来说这样的事情是不可能的,因为张量是不可修改的。我想sparse tensor 或tf.gather_nd 都可以,但两者都需要生成索引数组,这对我来说又是一项不平凡的任务。
还有一点需要注意:我不想通过tf.py_func 使用任何python 函数,因为我检查了它们是否仅在CPU 上运行。也许 lambda 和 tf.map_fn 可能会有所帮助?谢谢!
【问题讨论】:
标签: python tensorflow