【发布时间】:2019-12-30 22:32:57
【问题描述】:
我正在将代码从 tensorflow 移植到 numpy,但是这行代码遇到了问题:
tensor_unstack = tf.unstack(some_tensor, axis=0)
使用了tf.unstack 方法,我无法在 numpy.所以我的问题是在使用 numpy 时如何实现 tf.unstack?
【问题讨论】:
标签: python numpy tensorflow matrix data-science
我正在将代码从 tensorflow 移植到 numpy,但是这行代码遇到了问题:
tensor_unstack = tf.unstack(some_tensor, axis=0)
使用了tf.unstack 方法,我无法在 numpy.所以我的问题是在使用 numpy 时如何实现 tf.unstack?
【问题讨论】:
标签: python numpy tensorflow matrix data-science
star operator 可用于取消堆叠 numpy 数组。这是一个例子:
import tensorflow as tf
import numpy as np
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c = np.stack([a, b])
*d, = c
print(d)
c_ = tf.stack([a, b])
d_ = tf.unstack(c_)
with tf.Session() as sess:
print(sess.run(d_))
【讨论】:
上面的答案不允许指定分割轴或想要进行的分割数。
值得庆幸的是,内置的 numpy 函数提供了更好的解决方案。查看numpy.split 及其专用版本numpy.hsplit、numpy.vsplit、numpy.dsplit 和numpy.array_split。
import numpy
a = numpy.arange(9).reshape(3,3)
# makes 3 equal splits along axis 0. equivalent to numpy.vsplit
print(numpy.split(a, 3, axis=0))
# 3 equal splits along axis 1. equivalent to numpy.hsplit
print(numpy.split(a, 3, axis=1))
【讨论】:
x.shape: (4, 2, 3),则x1, x2 = np.hsplit(x,2) 将产生x1.shape: (4, 1, 3)。可以使用squeeze 删除未堆叠的维度:x1, x2 = [ xx.squeeze() for xx in np.hsplit(x,2)],现在是x1.shape: (4, 3)