【问题标题】:How to overwrite operator in C++ class with a variadic function?如何用可变参数函数覆盖 C++ 类中的运算符?
【发布时间】:2021-05-31 09:07:59
【问题描述】:

这里的C++新手:我想创建一个模板类来创建不同数据类型和d维度的张量,其中d由一个形状指定。例如,形状为(2, 3, 5) 的张量有 3 个维度,包含 24 个元素。我使用一维向量存储所有数据元素,并希望使用形状信息访问元素以查找元素。

我想覆盖() 运算符来访问元素。由于维度可以变化,() 运算符的输入参数数量也可以变化。从技术上讲,我可以使用向量作为输入参数,但 C++ 似乎也支持可变参数函数。但是,我无法理解它。

到目前为止我所拥有的:

#ifndef TENSOR_HPP
#define TENSOR_HPP

#include <vector>
#include <numeric>
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <stdarg.h>


template <typename T> class Tensor {

    private:
        std::vector<T> m_data;
        std::vector<std::size_t> m_shape;
        std::size_t m_size;
        
    public:
        // Constructors
        Tensor(std::vector<T> data, std::vector<std::size_t> shape);

        // Destructor
        ~Tensor();

        // Access the individual elements                                                                                                                                                                                               
        T& operator()(std::size_t&... d_args);
        
};


template <typename T> Tensor<T>::Tensor(std::vector<T> data, std::vector<std::size_t> shape) {
    // Calculate number of data values based on shape
    m_size = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<std::size_t>());
    // Check if calculated number of values match the actual number
    if (data.size() != m_size) {
        throw std::length_error("Tensor shape does not match the number of data values");
    } 
    // All good from here
    m_data = data;
    m_shape = shape;
}

template <typename T> T& Tensor<T>::operator() (std::size_t&... d_args) {
    // Return something to avoid warning
    return m_data[0];
};

template <typename T> Tensor<T>::~Tensor() {
    //delete[] m_values;
};


#endif

当我执行以下操作时不会:

std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
std::vector<std::size_t> shape = {2, 3, 4};
Tensor<float> tensor(data, shape);

tensor(2,0,3); // <-- What I would like to do

// Possible workaround with vector which I would like to avoid
// std::vector<std::size_t> index = {2,0,3};
// tensor(index);

我得到错误:

tensor2.hpp:27:33: error: expansion pattern ‘std::size_t&’ {aka ‘long unsigned int&’} contains no parameter packs

使用可变参数函数覆盖() 运算符的正确方法是什么?

【问题讨论】:

  • Tensor&lt;float, 2, 3, 4&gt; 怎么样?
  • 您的编译器不应允许您将纯右值作为具有非 const 限定引用类型的参数传递...
  • @Jarod42 以这种方式定义张量当然没问题。这将如何影响代码。同样,这应该是灵活的,因为我可能希望将 24 个值存储在 Tensor&lt;float, 2, 12&gt;Tensor&lt;float, 12, 2&gt;Tensor&lt;float, 1, 1, 1, 1, 24&gt;Tensor&lt;float, 1, 12, 1, 2&gt; 等中。假设这里会有一个变化template &lt;typename T, ???&gt; class Tensor() 运算符的定义是什么样的?
  • 术语:你想重载一个操作符,而不是覆盖(后者是一个非术语)。
  • 如果形状是 {2,3,4} 并且您尝试访问 tensor(0,0),您希望发生什么?

标签: c++ variadic-templates variadic-functions


【解决方案1】:

通过提供“形状”作为模板参数,您可以这样做:

// Helper for folding to specific type
template <std::size_t, typename T> using always_type = T;

// Your Tensor class
template <typename T, std::size_t... Dims>
class MultiArray
{
public:

    explicit MultiArray(std::vector<T> data) : values(std::move(data))
    {
        assert(values.size() == (1 * ... * Dims));
    }

    const T& get(const std::array<std::size_t, sizeof...(Dims)>& indexes) const
    {
        return values[computeIndex(indexes)];
    }
    T& get(const std::array<std::size_t, sizeof...(Dims)>& indexes)
    {
        return values[computeIndex(indexes)];
    }

    const T& get(always_type<Dims, std::size_t>... indexes) const
    {
        return get({{indexes...}});
    }
    T& get(always_type<Dims, std::size_t>... indexes)
    {
        return get({{indexes...}});
    }

    static std::size_t computeIndex(const std::array<std::size_t, sizeof...(Dims)>& indexes)
    {
        constexpr std::array<std::size_t, sizeof...(Dims)> dimensions{{Dims...}};
        size_t index = 0;
        size_t mul = 1;

        for (size_t i = dimensions.size(); i != 0; --i) {
            assert(indexes[i - 1] < dimensions[i - 1]);
            index += indexes[i - 1] * mul;
            mul *= dimensions[i - 1];
        }
        assert(index < (1 * ... * Dims));
        return index;
    }

    static std::array<std::size_t, sizeof...(Dims)> computeIndexes(std::size_t index)
    {
        assert(index < (1 * ... * Dims));

        constexpr std::array<std::size_t, sizeof...(Dims)> dimensions{{Dims...}};
        std::array<std::size_t, sizeof...(Dims)> res;

        std::size_t mul = (1 * ... * Dims);
        for (std::size_t i = 0; i != dimensions.size(); ++i) {
            mul /= dimensions[i];
            res[i] = index / mul;
            assert(res[i] < dimensions[i]);
            index -= res[i] * mul;
        }
        return res;
    }

private:
    std::vector<T> values; // possibly: std::array<T, (1 * ... * Dims)>
};

用法类似于

std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
MultiArray<float, 2, 3, 4> tensor(data);
std::cout << tensor.get(1, 0, 3); // 16

Demo

【讨论】:

  • 好的,哇...这要消化很多:)。我已经有一个适用于矩阵的解决方案(即 2d 传感器;其中() 运算符总是正好得到 2 个参数),我“只”想用任意维数推广这个解决方案张量。我想我在这里有点天真。以我目前(和有限)的 C/C++ 知识,我什至无法正确阅读您的代码。
  • ComputeIndex(es) 是将线性索引转换为多索引的“算法”部分,反之亦然。然后我得到get 将索引作为数组(2 个版本 const/non const)和你想要使用 always_type 摆脱数组语法的那个(并将其转发到数组语法)。
  • 是的,我可以遵循从线性索引到多维索引的原理映射;在纸上,我已经有了一个解决方案作为我的 Matrix/2d-Tensor 类的扩展。但是有很多语法我还不熟悉(例如,双花括号和省略号的所有用法)。但这只是我必须阅读的内容。
  • std::array 主要是一个带有 C 数组的结构,因此聚合初始化需要双花括号。 (C 数组有奇怪的语法,特别是供参考,std::array 有那个故障 :-/)。 Variadic 引入了可能不熟悉的省略号用法。如果您有任何问题,请不要犹豫。
  • 好吧,我想我已经了解了模板和参数包背后的基本概念。我认为是这两个概念的结合让我感到厌烦。使用https://cppinsights.io/ 进行查看,实例化的类和函数有所帮助。但是我可以正确地弄清楚我们在做什么模板别名。它真的需要吗?如果没有它,代码会如何变化?
【解决方案2】:

您可以添加具有尽可能多的重载的辅助函数来计算正确的索引来访问项目:

    T& getData(int dim1) { return m_data[dim1];}
    T& getData(int dim1, int dim2) { return m_data[ dim1* m_shape[1] + dim2 ];}
    T& getData(int dim1, int dim2, int dim3) { return m_data[ dim1*m_shape[1]*m_shape[2] + dim2*m_shape[2] + dim3 ];}

那么operator() 可能看起来像:

    template<class ... Args>                                                                                                                                                                                           
    T& operator()(Args... d_args) {
        static_assert( (std::is_integral_v<Args> && ...) ); // [1]
        return getData(d_args...);
    }

通过 [1] 我们限制 () 仅用于整数类型。

Live demo

【讨论】:

  • 此代码仅在没有static_assert 行的情况下运行,因为我当前安装了 gcc 10.2。我只是想知道,以这种方式定义一组辅助函数是一种很好的通用方法。可以说,在实践中我不需要超过 10 维的张量,但在哪里放置最大维数并因此帮助函数并不明显。
  • 可以通过-std=c++17开启c++17 in 10.2。
  • 啊,我明白了……那行得通!我不知道我必须明确打开它。只看到here说C++17应该是10.2支持的,一脸懵。
猜你喜欢
  • 1970-01-01
  • 2011-08-15
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2011-05-31
  • 1970-01-01
相关资源
最近更新 更多