【发布时间】:2021-06-12 01:40:33
【问题描述】:
我正在尝试用 C++ 构建一个张量类。这仅适用于个人项目以在 C++ 中进行一些练习,并且它有点工作,但现在我遇到了一些我不太理解的 C++ 问题。这是我的张量类的主要结构,省略了一些不相关的函数。
#include <vector>
template<typename T> struct Tensor {
// Support for up to 5 dimensions
T& at(std::vector<std::size_t> indices); // Calls the other at() functions depending on the vector size
T& at(std::size_t d1);
T& at(std::size_t d1, std::size_t d2);
T& at(std::size_t d1, std::size_t d2, std::size_t d3);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5);
Tensor transpose();
Tensor matmul(Tensor& rhs) {
rhs.at(0);
return Tensor();
}
};
在测试 matmul() 以将两个张量相乘时,以下代码有效:
void works() {
Tensor<float> tensor1;
Tensor<float> tensor2 = tensor1.transpose();
Tensor<float> tensor3 = tensor1.matmul(tensor2);
}
但是,以下代码我没有明确创建 tensor1 作为 tensor1 的转置失败:
void fails() {
Tensor<float> tensor1;
Tensor<float> tensor = tensor1.matmul(tensor1.transpose());
}
抛出的第一个错误是
tensor.test.cpp:26:60: error: cannot bind non-const lvalue reference of type 'Tensor<float>&' to an rvalue of type 'Tensor<float>'
26 | Tensor<float> tensor = tensor1.matmul(tensor1.transpose());
| ~~~~~~~~~~~~~~~~~^~
根据我的谷歌搜索和对 C++ 的有限理解,我尝试将 matmul() 的定义更改为
Tensor matmul(const Tensor& rhs) { // <-- const added
但是,如果我这样做,我会得到一个不同的错误:
tensor.hpp: In instantiation of 'Tensor<T> Tensor<T>::matmul(const Tensor<T>&) [with T = float]':
tensor.test.cpp:21:43: required from here
tensor.hpp:13:15: error: passing 'const Tensor<float>' as 'this' argument discards qualifiers [-fpermissive]
13 | rhs.at(0);
| ~~~~~~^~~
编辑 1: 在 cmets 之后,使用 Tensor matmul(const Tensor& rhs); 添加所有 const 定义的所有 at() 函数解决了我的问题。但是现在代码看起来像这样:
T& at(std::vector<std::size_t> indices);
T& at(std::size_t d1);
T& at(std::size_t d1, std::size_t d2);
T& at(std::size_t d1, std::size_t d2, std::size_t d3);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5);
const T& at(std::vector<std::size_t> indices) const;
const T& at(std::size_t d1) const;
const T& at(std::size_t d1, std::size_t d2) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5) const;
这意味着相当多的重复代码。我想知道这是否可以改进。
【问题讨论】:
标签: c++ c++17 pass-by-reference