代码之家  ›  专栏  ›  技术社区  ›  R zu

用eigen3库编写一个以转置为参数的函数

  •  3
  • R zu  · 技术社区  · 6 年前

    1. 密集阵列/矩阵,或
    2. 密集数组/矩阵的转置。

    有没有可能避免使用完美转发?

    我试过使用 DenseBase 类型参数。但这不能接受矩阵的转置。

    当前解决方案:

    #include <Eigen/Eigen>
    #include <iostream>
    
    using namespace Eigen;
    
    template <typename U>
    auto f(U&& x) {
        auto x2 = std::forward<U>(x);
        auto max_x = x2.colwise().maxCoeff().eval();
        x2 = x2.rowwise() + max_x;
        return max_x;
    }
    
    int main() {
        Array<float, 3, 3> M1;
        M1 << 1, 2, 3, 
              4, 5, 6, 
              7, 8, 9;
        std::cout << M1 << "\n";
        // auto here might cause problem later ...
        // see eigen.tuxfamily.org/dox/TopicPitfalls.html
        auto max_x = f(M1.transpose());
        std::cout << M1 << "\n";
        std::cout << max_x << "\n";
    }
    

    结果:

    // original
    1 2 3
    4 5 6
    7 8 9
    // Increase each row by max of the row.
     4  5  6
    10 11 12
    16 17 18
    // Max of each row (not a column vector).
    3 6 9
    

    我试过了 EigenBase 使用以下行:

    template <typename U>
    auto f(EigenBase<U>& x) {
    ...
    

    test4.cpp:20:32: error: cannot bind non-const lvalue reference of type ‘Eigen::EigenBase<Eigen::Transpose<Eigen::Array<float, 3, 3> > >&’ to an rvalue of type ‘Eigen::EigenBase<Eigen::Transpose<Eigen::Array<float, 3, 3> > >’
         auto max_x = f(M1.transpose());
                        ~~~~~~~~~~~~^~
    
    2 回复  |  直到 6 年前
        1
  •  0
  •   R zu    6 年前

    用左值引用和右值引用参数重载函数似乎可以解决这个问题。

    我怀疑转置是唯一生成临时c++表达式的东西。如果我希望转置版本返回一个转置结果(列而不是行向量),而非转置版本返回一个非转置结果,这就不能解决这个问题。

    谢谢 kmdreko

    测试:

    #include <Eigen/Eigen>
    #include <iostream>
    
    using namespace Eigen;
    
    template <typename Derived>
    auto f(DenseBase<Derived>& x) {
        auto max_x = x.colwise().maxCoeff().eval();
        x = x.rowwise() + max_x;
        return max_x;
    }
    
    template <typename Derived>
    auto f(DenseBase<Derived>&& x) {
        return f(x).transpose().eval();
    }
    
    int main() {
        Array<float, 3, 3> M1, M2;
        M1 << 1, 2, 3, 
              4, 5, 6, 
              7, 8, 9;
        M2 = M1;
        std::cout << M1 << "\n";
    
        std::cout << "no transpose\n";
        Array<float, 3, 1> max_x = f(M1);
        std::cout << M1 << "\n";
        std::cout << max_x << "\n";
    
        std::cout << "transpose\n";
        Array<float, 1, 3> max_x2 = f(M2.transpose());
        std::cout << M2 << "\n";
        std::cout << max_x2 << "\n";
    }
    

    结果:

    1 2 3
    4 5 6
    7 8 9
    no transpose
     8 10 12
    11 13 15
    14 16 18
    // returns column vector
    7
    8
    9
    transpose
     4  5  6
    10 11 12
    16 17 18
    // returns row vector
    3 6 9
    
        2
  •  0
  •   R zu    6 年前

    #include <Eigen/Eigen>
    #include <iostream>
    
    using namespace Eigen;
    
    template <typename Derived>
    auto f(DenseBase<Derived>& x) {
        auto max_x = x.colwise().maxCoeff().eval();
        x = x.rowwise() + max_x;
        return max_x;
    }
    
    int main() {
        Array<float, 3, 3> M1, M2;
        M1 << 1, 2, 3, 
              4, 5, 6, 
              7, 8, 9;
        M2 = M1;
        std::cout << M1 << "\n";
    
        std::cout << "no transpose\n";
        Array<float, 3, 1> max_x = f(M1);
        std::cout << M1 << "\n";
        std::cout << max_x << "\n";
    
        std::cout << "transpose\n";
        auto m2_t = M2.transpose();
        Array<float, 1, 3> max_x2 = f(m2_t);
        std::cout << M2 << "\n";
        std::cout << max_x2 << "\n";
    }
    

    结果:

    1 2 3
    4 5 6
    7 8 9
    no transpose
     8 10 12
    11 13 15
    14 16 18
    7
    8
    9
    transpose
     4  5  6
    10 11 12
    16 17 18
    3 6 9