【问题标题】:TensorFlow unstack with numpyTensorFlow unstack 与 numpy
【发布时间】:2019-12-30 22:32:57
【问题描述】:

我正在将代码从 tensorflow 移植到 numpy,但是这行代码遇到了问题:

tensor_unstack = tf.unstack(some_tensor, axis=0)

使用了tf.unstack 方法,我无法在 numpy.所以我的问题是在使用 numpy 时如何实现 tf.unstack?

【问题讨论】:

    标签: python numpy tensorflow matrix data-science


    【解决方案1】:

    star operator 可用于取消堆叠 numpy 数组。这是一个例子:

    import tensorflow as tf
    import numpy as np
    
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])
    c = np.stack([a, b])
    *d, = c
    print(d)
    
    c_ = tf.stack([a, b])
    d_ = tf.unstack(c_)
    
    with tf.Session() as sess:
        print(sess.run(d_))
    

    【讨论】:

      【解决方案2】:

      上面的答案不允许指定分割轴或想要进行的分割数。

      值得庆幸的是,内置的 numpy 函数提供了更好的解决方案。查看numpy.split 及其专用版本numpy.hsplitnumpy.vsplitnumpy.dsplitnumpy.array_split

      import numpy
      a = numpy.arange(9).reshape(3,3)
      
      # makes 3 equal splits along axis 0. equivalent to numpy.vsplit
      print(numpy.split(a, 3, axis=0))
      
      # 3 equal splits along axis 1. equivalent to numpy.hsplit
      print(numpy.split(a, 3, axis=1))
      

      【讨论】:

      • 然而,这将保留“未堆叠”维度:如果x.shape: (4, 2, 3),则x1, x2 = np.hsplit(x,2) 将产生x1.shape: (4, 1, 3)。可以使用squeeze 删除未堆叠的维度:x1, x2 = [ xx.squeeze() for xx in np.hsplit(x,2)],现在是x1.shape: (4, 3)
      猜你喜欢
      • 1970-01-01
      • 2017-07-30
      • 1970-01-01
      • 1970-01-01
      • 2021-06-12
      • 2017-06-30
      • 1970-01-01
      • 1970-01-01
      • 2018-06-23
      相关资源
      最近更新 更多