【问题标题】:mxnet (mshadow) getting the shape of a tensormxnet (mshadow) 获取张量的形状
【发布时间】:2017-06-16 06:41:53
【问题描述】:

我是 mshadow 的新手,我不明白为什么我会从以下代码 sn-p 中得到这些输出:

TensorContainer<cpu, 2> lhs(Shape2(2, 3));
lhs = 1.0;
printf("%u %u\n", lhs.size(0), lhs.size(1));
printf("%u %u\n", lhs[0].shape_[0], lhs[0].shape_[1]);
printf("%u %u\n", lhs[0].size(0), lhs[0].size(1));

输出是:

2 3
3 4
3 3

为什么第二个和第三个输出这些数字?因为lhs[0]是一维的,所以我认为它们应该是完全一样的,即3 0。谁能告诉我我错在哪里?提前致谢!

【问题讨论】:

    标签: shape mxnet tensor


    【解决方案1】:

    你是对的,张量 lhs[0] 是一维的,但首先要回答你的问题,让我展示一下幕后发生的事情。 TensorContainer 不会覆盖 [] 运算符,而是使用来自父级(即 Tensor)的运算符,更准确地说是调用 following one

      MSHADOW_XINLINE Tensor<Device, kSubdim, DType> operator[](index_t idx) const {
        return Tensor<Device, kSubdim, DType>(dptr_ + this->MemSize<1>() * idx,
                                              shape_.SubShape(), stride_, stream_);
      }
    

    可以看出,它在堆栈上创建了一个新的张量。虽然对于大多数情况,它会创建一个通用的N-dimensional Tensor,但对于一维情况,它会创建一个特殊的1-dimensional Tensor

    现在,当我们确定了运算符 [] 究竟返回了什么后,让我们看看该类的字段:

      DType *dptr_;
      Shape<1> shape_;
      index_t stride_;
    

    可以看出这里的shape_只有一维!所以没有 shape_1,而是通过调用 shape_1 它将返回 stride_(或它的一部分)。这是对 Tensor 构造函数的修改,您可以尝试运行它并查看那里实际发生了什么:

      MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape,
                             index_t stride, Stream<Device> *stream)
          : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {
         std::cout << "shape[0]: " << shape[0] << std::endl; // 3
         std::cout << "shape[1]: " << shape[1] << std::endl; // 0, as expected
         std::cout << "_shape[0]: " << shape_[0] << std::endl; // 3, as expected
         std::cout << "_shape[1]: " << shape_[1] << std::endl; // garbage (4)
         std::cout << "address of _shape[1]: " << &(shape_[1]) << std::endl;
         std::cout << "address of stride: " << &(stride_) << std::endl;
      }
    

    和输出:

    shape[0]: 3
    shape[1]: 0
    _shape[0]: 3
    _shape[1]: 4
    address of _shape[1]: 0x7fffa28ec44c
    address of stride: 0x7fffa28ec44c
    

    _shape1 和 stride 的地址相同(0x7fffa28ec44c)。

    【讨论】:

    • 绝妙的答案!非常感谢!
    • @ROBOTAI 如果您认为这是正确的答案,请问您将其标记为正确吗?
    猜你喜欢
    • 1970-01-01
    • 2019-01-30
    • 2023-01-13
    • 2019-09-01
    • 2018-12-13
    • 2022-01-03
    • 2016-12-12
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多