代码之家  ›  专栏  ›  技术社区  ›  tjwrona1992

将计算从运行时转移到编译时如何使我的代码变慢?

  •  2
  • tjwrona1992  · 技术社区  · 4 年前

    我正在编写一个函数,用Stockham FFT算法计算快速傅里叶变换,发现如果FFT的长度是2的幂,那么计算的所有“旋转因子”都可以在编译时预先计算。

    在FFT计算中,旋转因子的计算通常会占用总时间的很大一部分,因此从理论上讲,这样做应该会大大提高算法的性能。

    昨天我花了一整天的时间,在新版编译器(gcc 10)上重新实现了这个算法,以便使用 C++20 的 consteval 特性在编译时预先计算所有旋转因子。我成功做到了这一点,但最终发现,在编译时预先计算所有旋转因子的版本实际上运行得反而更慢!

    以下是在运行时执行所有计算的代码:

    #include <algorithm>
    #include <cassert>
    #include <chrono>
    #include <cmath>
    #include <complex>
    #include <iostream>
    #include <vector>
    
    using namespace std;
    
    static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x);
    
    // True exactly when `value` has a single bit set, i.e. value == 2^k for
    // some k >= 0. Zero is rejected explicitly because 0 & (0 - 1) == 0.
    constexpr bool IsPowerOf2(const size_t value)
    {
        if (value == 0)
        {
            return false;
        }
        // Clearing the lowest set bit leaves zero only for a single-bit value.
        return (value & (value - 1)) == 0;
    }
    
    // Computes the discrete Fourier transform of a real-valued signal.
    // Precondition: x.size() is a power of two (checked only by assert).
    // The real samples are widened to complex values (imaginary part 0) and
    // handed to StockhamFFT, whose result is returned unchanged.
    vector<complex<double>> FFT(const vector<double>& x)
    {
        assert(IsPowerOf2(x.size()));

        // Each double converts to complex<double>(value, 0.0), so the
        // iterator-range constructor does the widening in one step.
        const vector<complex<double>> x_p(x.begin(), x.end());

        return StockhamFFT(x_p);
    }
    
    // Radix-2 Stockham (autosort) FFT computed entirely at run time.
    // Precondition: x.size() is a power of two (checked only by assert).
    // Returns X[k] = sum_j x[j] * exp(-2*pi*i*j*k/N), in natural order.
    static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x)
    {
        const auto N = x.size();
        assert(IsPowerOf2(N));
        const auto half = N/2;

        // Twiddle table: twiddles[k] = exp(-2*pi*i*k/N) for k in [0, N/2),
        // rebuilt on every call.
        const auto omega = 2.0 * M_PI / N;
        vector<complex<double>> twiddles;
        twiddles.reserve(half);
        for (size_t k = 0; k < half; ++k)
        {
            const double angle = -omega * k;
            twiddles.emplace_back(cos(angle), sin(angle));
        }

        // Ping-pong buffers: every stage reads `src` and writes `dst`, then the
        // two are exchanged so the newest data is always in `src`.
        vector<complex<double>> src(x);
        vector<complex<double>> dst(N);

        // The twiddle spacing halves each stage: N/2, N/4, ..., 1.
        size_t twiddleStep = half;

        for (size_t span = 1; span < N; span *= 2)
        {
            for (size_t group = 0; group < half; group += span)
            {
                const auto outBase = 2*group;

                for (size_t j = 0; j < span; ++j)
                {
                    const auto lo = src[group + j];
                    const auto hi = twiddles[j*twiddleStep] * src[group + j + half];

                    // Radix-2 butterfly: one twiddled product feeds both outputs.
                    dst[outBase + j] = lo + hi;
                    dst[outBase + j + span] = lo - hi;
                }
            }

            twiddleStep /= 2;
            src.swap(dst);
        }

        return src;
    }
    
    // Benchmark driver: times kRuns calls of FFT on 2^18 samples of a 100 Hz
    // sine sampled at 1 kHz, then prints the mean time per call in microseconds.
    // NOTE: timings are only meaningful with optimisation enabled (e.g. -O2).
    int main()
    {
        // Exact integer power of two; avoids the double -> size_t conversion
        // that pow(2, 18) requires.
        const size_t N = size_t{1} << 18;
        vector<double> x(N);

        int f_s = 1000;
        double t_s = 1.0 / f_s;

        for (size_t n = 0; n < N; ++n)
        {
            x[n] = sin(2 * M_PI * 100 * n * t_s);
        }

        // The number of timed runs and the divisor must be the same constant,
        // otherwise the reported average is biased.
        constexpr int kRuns = 100;

        // chrono::microseconds has a 64-bit representation, so the total cannot
        // be truncated the way an `int` accumulator could. steady_clock is
        // monotonic; high_resolution_clock may alias the adjustable system_clock.
        chrono::microseconds total{0};
        for (int i = 0; i < kRuns; ++i)
        {
            const auto start = chrono::steady_clock::now();
            [[maybe_unused]] auto X = FFT(x);
            const auto stop = chrono::steady_clock::now();
            total += chrono::duration_cast<chrono::microseconds>(stop - start);
        }
        const auto average = total.count() / kRuns;

        std::cout << "duration " << average << " microseconds." << std::endl;
    }
    

    以此为起点,我把旋转因子的计算从 StockhamFFT 函数中提取出来,改为由一个 consteval 函数在编译时完成。修改后的代码如下:

    #include <algorithm>
    #include <array>
    #include <cassert>
    #include <chrono>
    #include <cmath>
    #include <complex>
    #include <iostream>
    #include <vector>
    
    using namespace std;
    
    static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x);
    
    // True exactly when `value` has a single bit set, i.e. value == 2^k for
    // some k >= 0. Zero is rejected explicitly because 0 & (0 - 1) == 0.
    constexpr bool IsPowerOf2(const size_t value)
    {
        if (value == 0)
        {
            return false;
        }
        // Clearing the lowest set bit leaves zero only for a single-bit value.
        return (value & (value - 1)) == 0;
    }
    
    // Builds, during compilation, the first N/2 twiddle factors of a length-N
    // FFT: table[k] = exp(-2*pi*i*k/N) for k in [0, N/2).
    // NOTE: calling cos/sin inside a consteval function relies on GCC treating
    // them as constant-evaluable builtins; standard C++20 does not guarantee it.
    template <size_t N>
    static consteval array<complex<double>, N/2> CalculateTwiddleFactors()
    {
        static_assert(IsPowerOf2(N), "N must be a power of 2.");

        constexpr auto half = N/2;
        const auto step = 2.0*M_PI/N;

        array<complex<double>, half> table;
        size_t k = 0;
        while (k < half)
        {
            const double angle = -step*k;
            table[k] = complex{cos(angle), sin(angle)};
            ++k;
        }

        return table;
    }
    
    // Twiddle-factor table for a transform of length 2^18, evaluated at
    // compile time by the consteval CalculateTwiddleFactors.
    // `size_t{1} << 18` is an exact integer constant expression. The previous
    // static_cast<size_t>(pow(2,18)) only compiled because GCC constant-folds
    // std::pow as a builtin; std::pow is not constexpr in standard C++20.
    constexpr auto W = CalculateTwiddleFactors<(size_t{1} << 18)>();
    
    // Computes the discrete Fourier transform of a real-valued signal.
    // Precondition: x.size() is a power of two (checked only by assert).
    // The real samples are widened to complex values (imaginary part 0) and
    // handed to StockhamFFT, whose result is returned unchanged.
    vector<complex<double>> FFT(const vector<double>& x)
    {
        assert(IsPowerOf2(x.size()));

        // Each double converts to complex<double>(value, 0.0), so the
        // iterator-range constructor does the widening in one step.
        const vector<complex<double>> x_p(x.begin(), x.end());

        return StockhamFFT(x_p);
    }
    
    // C++ implementation of the Stockham FFT algorithm, using the twiddle
    // table W that is evaluated at compile time.
    //
    // W[k] = exp(-2*pi*i*k/M) for k in [0, M/2), where M = 2*W.size() is the
    // length the table was built for. A length-N transform (N a power of two,
    // N <= M) needs exp(-2*pi*i*j/N) == W[j*(M/N)]. The twiddle stride
    // therefore starts at M/2 (== W.size()), halves after every stage, and
    // ends at M/N.
    //
    // Starting the stride at N/2 is only correct when N == M. For smaller N it
    // picks the wrong twiddles, and for larger N it reads past the end of W.
    // For N == M both choices are identical.
    //
    // Precondition: x.size() is a power of two no larger than 2*W.size()
    // (checked only by assert).
    static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x)
    {
        const auto N = x.size();
        assert(IsPowerOf2(N));
        // The compile-time table only covers transforms up to its own length.
        assert(N <= 2*W.size());
        const auto NOver2 = N/2;

        // The Stockham algorithm requires one vector for input/output data and
        // another as a temporary workspace
        vector<complex<double>> a(x);
        vector<complex<double>> b(N);

        // Spacing between twiddle factors used at the first stage, measured in
        // entries of the compile-time table rather than of a length-N table.
        size_t WStride = W.size();

        // Loop through each stage of the FFT
        for (size_t stride = 1; stride < N; stride *= 2)
        {
            // Loop through the individual FFTs of each stage
            for (size_t m = 0; m < NOver2; m += stride)
            {
                const auto mTimes2 = m*2;

                // Perform each individual FFT
                for (size_t n = 0; n < stride; ++n)
                {
                    // Calculate the input indexes
                    const auto aIndex1 = n + m;
                    const auto aIndex2 = aIndex1 + NOver2;

                    // Calculate the output indexes
                    const auto bIndex1 = n + mTimes2;
                    const auto bIndex2 = bIndex1 + stride;

                    // Perform the FFT
                    const auto tmp1 = a[aIndex1];
                    const auto tmp2 = W[n*WStride]*a[aIndex2];

                    // Radix-2 butterfly: the same twiddled product feeds both outputs
                    b[bIndex1] = tmp1 + tmp2;
                    b[bIndex2] = tmp1 - tmp2;
                }
            }

            // Spacing between twiddle factors is half for the next stage
            WStride /= 2;

            // Swap the data (output of this stage is input of the next)
            a.swap(b);
        }

        return a;
    }
    
    // Benchmark driver: times kRuns calls of FFT on 2^18 samples of a 100 Hz
    // sine sampled at 1 kHz, then prints the mean time per call in microseconds.
    // NOTE: timings are only meaningful with optimisation enabled (e.g. -O2).
    int main()
    {
        // Exact integer power of two; avoids the double -> size_t conversion
        // that pow(2, 18) requires.
        const size_t N = size_t{1} << 18;
        vector<double> x(N);

        int f_s = 1000;
        double t_s = 1.0 / f_s;

        for (size_t n = 0; n < N; ++n)
        {
            x[n] = sin(2 * M_PI * 100 * n * t_s);
        }

        // The number of timed runs and the divisor must be the same constant,
        // otherwise the reported average is biased.
        constexpr int kRuns = 100;

        // chrono::microseconds has a 64-bit representation, so the total cannot
        // be truncated the way an `int` accumulator could. steady_clock is
        // monotonic; high_resolution_clock may alias the adjustable system_clock.
        chrono::microseconds total{0};
        for (int i = 0; i < kRuns; ++i)
        {
            const auto start = chrono::steady_clock::now();
            [[maybe_unused]] auto X = FFT(x);
            const auto stop = chrono::steady_clock::now();
            total += chrono::duration_cast<chrono::microseconds>(stop - start);
        }
        const auto average = total.count() / kRuns;

        std::cout << "duration " << average << " microseconds." << std::endl;
    }
    

    这两个版本都是在Ubuntu 19.10和gcc 10.0.1上编译的:

    g++ -std=c++2a -o main main.cpp
    

    请注意,这里必须使用 gcc 编译器,因为它是唯一支持 constexpr 版本 sin 和 cos 的编译器。

    “运行时”示例产生以下结果:

    duration 292854 microseconds.

    “编译时”示例产生以下结果:

    duration 295230 microseconds.

    编译时版本的编译时间确实更长,但它的运行时间居然也更长——尽管大部分计算在程序启动之前就已经完成了!这怎么可能?我是不是遗漏了什么关键的东西?

    0 回复  |  直到 4 年前