代码之家  ›  专栏  ›  技术社区  ›  Royi

使用AVX实现指数函数的最快速度

  •  13
  • Royi  · 技术社区  · 7 年前

    我正在寻找一种在AVX元素(单精度浮点)上运行的指数函数的有效(快速)近似值。即- __m256 _mm256_exp_ps( __m256 x ) 没有SVML。

    相对精度应该大约为1e-6或20尾数位(2^20中的1部分)。

    如果它是用C风格和Intel Intrinsic编写的,我会很高兴。
    代码应该是可移植的(Windows、macOS、Linux、MSVC、ICC、GCC等)。


    这类似于 Fastest Implementation of Exponential Function Using SSE ,但这个问题的精度很低,但速度很快(目前的答案精度大约为1e-3)。

    此外,这个问题是寻找AVX/AVX2(和FMA)。但请注意,这两个问题的答案很容易在SSE4之间移植 __m128 或AVX2 __m256 ,因此未来的读者应根据所需的精度/性能权衡进行选择。

    5 回复  |  直到 7 年前
        1
  •  9
  •   wim    7 年前

    这个 exp 功能来自 avx_mathfun 使用范围缩减结合切比雪夫近似多项式来计算8 经验值 -与AVX指令并行。使用正确的编译器设置以确保 addps mulps 在可能的情况下,融合到FMA说明中。

    改编原作非常简单 经验值 代码来自 avx\U mathfun 可移植(跨不同编译器)C/AVX2内部代码。原始代码使用gcc样式的对齐属性和巧妙的宏 _mm256_set1_ps() 相反,它位于小测试代码和表的下面。修改后的代码需要AVX2。

    以下代码用于简单测试:

    int main(){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
        __m256 y = exp256_ps(x);
        _mm256_store_ps(xv,x);
        _mm256_store_ps(yv,y);
    
        for (i=0;i<8;i++){
            printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
        }
        return 0;
    }
    

    输出似乎正常:

    i = 0, x = 1.000000e+00, y = 2.718282e+00 
    i = 1, x = 2.000000e+00, y = 7.389056e+00 
    i = 2, x = 3.000000e+00, y = 2.008554e+01 
    i = 3, x = 4.000000e+00, y = 5.459815e+01 
    i = 4, x = 5.000000e+00, y = 1.484132e+02 
    i = 5, x = 6.000000e+00, y = 4.034288e+02 
    i = 6, x = 7.000000e+00, y = 1.096633e+03 
    i = 7, x = 8.000000e+00, y = 2.980958e+03 
    

    修改后的代码(AVX2)为:

    #include <stdio.h>
    #include <immintrin.h>
    /*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c    */
    
    __m256 exp256_ps(__m256 x) {
    /* Modified code. The original code is here: https://github.com/reyoung/avx_mathfun
    
       AVX implementation of exp
       Based on "sse_mathfun.h", by Julien Pommier
       http://gruntthepeon.free.fr/ssemath/
       Copyright (C) 2012 Giovanni Garberoglio
       Interdisciplinary Laboratory for Computational Science (LISC)
       Fondazione Bruno Kessler and University of Trento
       via Sommarive, 18
       I-38123 Trento (Italy)
      This software is provided 'as-is', without any express or implied
      warranty.  In no event will the authors be held liable for any damages
      arising from the use of this software.
      Permission is granted to anyone to use this software for any purpose,
      including commercial applications, and to alter it and redistribute it
      freely, subject to the following restrictions:
      1. The origin of this software must not be misrepresented; you must not
         claim that you wrote the original software. If you use this software
         in a product, an acknowledgment in the product documentation would be
         appreciated but is not required.
      2. Altered source versions must be plainly marked as such, and must not be
         misrepresented as being the original software.
      3. This notice may not be removed or altered from any source distribution.
      (this is the zlib license)
    */
    /* 
      To increase the compatibility across different compilers the original code is
      converted to plain AVX2 intrinsics code without ingenious macro's,
      gcc style alignment attributes etc. The modified code requires AVX2
    */
    __m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
    __m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);
    
    __m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341);
    __m256   cephes_exp_C1 = _mm256_set1_ps(0.693359375);
    __m256   cephes_exp_C2 = _mm256_set1_ps(-2.12194440e-4);
    
    __m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
    __m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
    __m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
    __m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
    __m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
    __m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
    __m256   tmp           = _mm256_setzero_ps(), fx;
    __m256i  imm0;
    __m256   one           = _mm256_set1_ps(1.0f);
    
            x     = _mm256_min_ps(x, exp_hi);
            x     = _mm256_max_ps(x, exp_lo);
    
      /* express exp(x) as exp(g + n*log(2)) */
            fx    = _mm256_mul_ps(x, cephes_LOG2EF);
            fx    = _mm256_add_ps(fx, _mm256_set1_ps(0.5f));
            tmp   = _mm256_floor_ps(fx);
    __m256  mask  = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);    
            mask  = _mm256_and_ps(mask, one);
            fx    = _mm256_sub_ps(tmp, mask);
            tmp   = _mm256_mul_ps(fx, cephes_exp_C1);
    __m256  z     = _mm256_mul_ps(fx, cephes_exp_C2);
            x     = _mm256_sub_ps(x, tmp);
            x     = _mm256_sub_ps(x, z);
            z     = _mm256_mul_ps(x,x);
    
    __m256  y     = cephes_exp_p0;
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p1);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p2);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p3);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p4);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p5);
            y     = _mm256_mul_ps(y, z);
            y     = _mm256_add_ps(y, x);
            y     = _mm256_add_ps(y, one);
    
      /* build 2^n */
            imm0  = _mm256_cvttps_epi32(fx);
            imm0  = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
            imm0  = _mm256_slli_epi32(imm0, 23);
    __m256  pow2n = _mm256_castsi256_ps(imm0);
            y     = _mm256_mul_ps(y, pow2n);
            return y;
    }
    
    int main(){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
        __m256 y = exp256_ps(x);
        _mm256_store_ps(xv,x);
        _mm256_store_ps(yv,y);
    
        for (i=0;i<8;i++){
            printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
        }
        return 0;
    }
    


    @Peter Cordes points out , 应该可以更换 _mm256_floor_ps(fx + 0.5f) 通过 _mm256_round_ps(fx) . 此外 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); 接下来的两行似乎是多余的。 通过组合可以进一步优化 cephes_exp_C1 cephes_exp_C2 进入 inv_LOG2EF . 这导致以下代码未经过彻底测试!
    #include <stdio.h>
    #include <immintrin.h>
    #include <math.h>
    /*    gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c -lm     */
    
    __m256 exp256_ps(__m256 x) {
    /* Modified code from this source: https://github.com/reyoung/avx_mathfun
    
       AVX implementation of exp
       Based on "sse_mathfun.h", by Julien Pommier
       http://gruntthepeon.free.fr/ssemath/
       Copyright (C) 2012 Giovanni Garberoglio
       Interdisciplinary Laboratory for Computational Science (LISC)
       Fondazione Bruno Kessler and University of Trento
       via Sommarive, 18
       I-38123 Trento (Italy)
      This software is provided 'as-is', without any express or implied
      warranty.  In no event will the authors be held liable for any damages
      arising from the use of this software.
      Permission is granted to anyone to use this software for any purpose,
      including commercial applications, and to alter it and redistribute it
      freely, subject to the following restrictions:
      1. The origin of this software must not be misrepresented; you must not
         claim that you wrote the original software. If you use this software
         in a product, an acknowledgment in the product documentation would be
         appreciated but is not required.
      2. Altered source versions must be plainly marked as such, and must not be
         misrepresented as being the original software.
      3. This notice may not be removed or altered from any source distribution.
      (this is the zlib license)
    
    */
    /* 
      To increase the compatibility across different compilers the original code is
      converted to plain AVX2 intrinsics code without ingenious macro's,
      gcc style alignment attributes etc.
      Moreover, the part "express exp(x) as exp(g + n*log(2))" has been significantly simplified.
      This modified code is not thoroughly tested!
    */
    
    
    __m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
    __m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);
    
    __m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341f);
    __m256   inv_LOG2EF    = _mm256_set1_ps(0.693147180559945f);
    
    __m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
    __m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
    __m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
    __m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
    __m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
    __m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
    __m256   fx;
    __m256i  imm0;
    __m256   one           = _mm256_set1_ps(1.0f);
    
            x     = _mm256_min_ps(x, exp_hi);
            x     = _mm256_max_ps(x, exp_lo);
    
      /* express exp(x) as exp(g + n*log(2)) */
            fx     = _mm256_mul_ps(x, cephes_LOG2EF);
            fx     = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
    __m256  z      = _mm256_mul_ps(fx, inv_LOG2EF);
            x      = _mm256_sub_ps(x, z);
            z      = _mm256_mul_ps(x,x);
    
    __m256  y      = cephes_exp_p0;
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p1);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p2);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p3);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p4);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p5);
            y      = _mm256_mul_ps(y, z);
            y      = _mm256_add_ps(y, x);
            y      = _mm256_add_ps(y, one);
    
      /* build 2^n */
            imm0   = _mm256_cvttps_epi32(fx);
            imm0   = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
            imm0   = _mm256_slli_epi32(imm0, 23);
    __m256  pow2n  = _mm256_castsi256_ps(imm0);
            y      = _mm256_mul_ps(y, pow2n);
            return y;
    }
    
    int main(){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(11.0f, -12.0f, 13.0f ,-14.0f ,15.0f, -16.0f, 17.0f, -18.0f);
        __m256 y = exp256_ps(x);
        _mm256_store_ps(xv,x);
        _mm256_store_ps(yv,y);
    
     /* compare exp256_ps with the double precision exp from math.h, 
        print the relative error             */
        printf("i      x                     y = exp256_ps(x)      double precision exp        relative error\n\n");
        for (i=0;i<8;i++){ 
            printf("i = %i  x =%16.9e   y =%16.9e   exp_dbl =%16.9e   rel_err =%16.9e\n",
               i,xv[i],yv[i],exp((double)(xv[i])),
               ((double)(yv[i])-exp((double)(xv[i])))/exp((double)(xv[i])) );
        }
        return 0;
    }
    

    下表通过将exp256\U ps与双精度进行比较,给出了某些点的精度印象 经验值 从…起 math.h . 相对误差在最后一列。

    i      x                     y = exp256_ps(x)      double precision exp        relative error
    
    i = 0  x = 1.000000000e+00   y = 2.718281746e+00   exp_dbl = 2.718281828e+00   rel_err =-3.036785947e-08
    i = 1  x =-2.000000000e+00   y = 1.353352815e-01   exp_dbl = 1.353352832e-01   rel_err =-1.289636419e-08
    i = 2  x = 3.000000000e+00   y = 2.008553696e+01   exp_dbl = 2.008553692e+01   rel_err = 1.672817689e-09
    i = 3  x =-4.000000000e+00   y = 1.831563935e-02   exp_dbl = 1.831563889e-02   rel_err = 2.501162103e-08
    i = 4  x = 5.000000000e+00   y = 1.484131622e+02   exp_dbl = 1.484131591e+02   rel_err = 2.108215155e-08
    i = 5  x =-6.000000000e+00   y = 2.478752285e-03   exp_dbl = 2.478752177e-03   rel_err = 4.380257261e-08
    i = 6  x = 7.000000000e+00   y = 1.096633179e+03   exp_dbl = 1.096633158e+03   rel_err = 1.849522682e-08
    i = 7  x =-8.000000000e+00   y = 3.354626242e-04   exp_dbl = 3.354626279e-04   rel_err =-1.101575118e-08
    
        2
  •  9
  •   njuffa    7 年前

    由于快速计算 exp() 需要操纵IEEE-754浮点操作数的指数字段, AVX 不太适合此计算,因为它缺少整数运算。因此,我将重点关注 AVX2 . 对融合乘法加法的支持在技术上是与 AVX2 因此,我提供了两条代码路径,使用和不使用FMA,由宏控制 USE_FMA .

    下面的代码计算 exp() 几乎 所需精度为10 -6 . FMA的使用在这里并没有提供任何显著的改进,但它应该在支持它的平台上提供性能优势。

    前一个 answer 对于精度较低的SSE实现,无法完全扩展到相当精确的实现,因为它包含一些数值特性较差的计算,但在这种情况下并不重要。而不是计算e x个 = 2 * 2 f 具有 f [0,1]或 f 在[-,]中,计算e是有利的 x个 = 2 *e类 f 具有 f 在较窄的间隔[-log 2,log 2],其中 log 表示自然对数。

    为此,我们首先计算 i = rint(x * log2(e)) 然后 f = x - log(2) * i . 重要的是 ,需要采用后一种计算方法 较高的 比本机精度更高,以提供准确的简化参数,并将其传递给核心近似值。为此,我们使用了Cody Waite方案,该方案首次发表在W.J.Cody&W.Waite,“基本功能的软件手册”,普伦蒂斯·霍尔,1980年。常数对数(2)分为较大幅值的“高”部分和较小幅值的“低”部分,这两部分保持了“高”部分和数学常数之间的差异。

    选择尾数中有足够尾数零位的高部分,以便 i “高”部分为 确切地 以本机精度表示。在这里,我选择了一个包含八个尾随零位的“高”部分,如下所示: 当然可以放入八位。

    本质上,我们计算f=x-i*log(2) 高的 -i*日志(2) 低的 . 这个简化的参数被传递到核心近似,它是一个多项式 minimax approximation ,结果按2缩放 正如前面的回答一样。

    #include <immintrin.h>
    
    #define USE_FMA 0
    
    /* compute exp(x) for x in [-87.33654f, 88.72283] 
       maximum relative error: 3.1575e-6 (USE_FMA = 0); 3.1533e-6 (USE_FMA = 1)
    */
    __m256 faster_more_accurate_exp_avx2 (__m256 x)
    {
        __m256 t, f, p, r;
        __m256i i, j;
    
        const __m256 l2e = _mm256_set1_ps (1.442695041f); /* log2(e) */
        const __m256 l2h = _mm256_set1_ps (-6.93145752e-1f); /* -log(2)_hi */
        const __m256 l2l = _mm256_set1_ps (-1.42860677e-6f); /* -log(2)_lo */
        /* coefficients for core approximation to exp() in [-log(2)/2, log(2)/2] */
        const __m256 c0 =  _mm256_set1_ps (0.041944388f);
        const __m256 c1 =  _mm256_set1_ps (0.168006673f);
        const __m256 c2 =  _mm256_set1_ps (0.499999940f);
        const __m256 c3 =  _mm256_set1_ps (0.999956906f);
        const __m256 c4 =  _mm256_set1_ps (0.999999642f);
    
        /* exp(x) = 2^i * e^f; i = rint (log2(e) * x), f = x - log(2) * i */
        t = _mm256_mul_ps (x, l2e);      /* t = log2(e) * x */
        r = _mm256_round_ps (t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); /* r = rint (t) */
    
    #if USE_FMA
        f = _mm256_fmadd_ps (r, l2h, x); /* x - log(2)_hi * r */
        f = _mm256_fmadd_ps (r, l2l, f); /* f = x - log(2)_hi * r - log(2)_lo * r */
    #else // USE_FMA
        p = _mm256_mul_ps (r, l2h);      /* log(2)_hi * r */
        f = _mm256_add_ps (x, p);        /* x - log(2)_hi * r */
        p = _mm256_mul_ps (r, l2l);      /* log(2)_lo * r */
        f = _mm256_add_ps (f, p);        /* f = x - log(2)_hi * r - log(2)_lo * r */
    #endif // USE_FMA
    
        i = _mm256_cvtps_epi32(t);       /* i = (int)rint(t) */
    
        /* p ~= exp (f), -log(2)/2 <= f <= log(2)/2 */
        p = c0;                          /* c0 */
    #if USE_FMA
        p = _mm256_fmadd_ps (p, f, c1);  /* c0*f+c1 */
        p = _mm256_fmadd_ps (p, f, c2);  /* (c0*f+c1)*f+c2 */
        p = _mm256_fmadd_ps (p, f, c3);  /* ((c0*f+c1)*f+c2)*f+c3 */
        p = _mm256_fmadd_ps (p, f, c4);  /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
    #else // USE_FMA
        p = _mm256_mul_ps (p, f);        /* c0*f */
        p = _mm256_add_ps (p, c1);       /* c0*f+c1 */
        p = _mm256_mul_ps (p, f);        /* (c0*f+c1)*f */
        p = _mm256_add_ps (p, c2);       /* (c0*f+c1)*f+c2 */
        p = _mm256_mul_ps (p, f);        /* ((c0*f+c1)*f+c2)*f */
        p = _mm256_add_ps (p, c3);       /* ((c0*f+c1)*f+c2)*f+c3 */
        p = _mm256_mul_ps (p, f);        /* (((c0*f+c1)*f+c2)*f+c3)*f */
        p = _mm256_add_ps (p, c4);       /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
    #endif // USE_FMA
    
        /* exp(x) = 2^i * p */
        j = _mm256_slli_epi32 (i, 23); /* i << 23 */
        r = _mm256_castsi256_ps (_mm256_add_epi32 (j, _mm256_castps_si256 (p))); /* r = p * 2^i */
    
        return r;
    }
    

    如果需要更高的精度,可以使用以下一组系数将多项式近似度提高1:

    /* maximum relative error: 1.7428e-7 (USE_FMA = 0); 1.6586e-7 (USE_FMA = 1) */
    const __m256 c0 =  _mm256_set1_ps (0.008301110f);
    const __m256 c1 =  _mm256_set1_ps (0.041906696f);
    const __m256 c2 =  _mm256_set1_ps (0.166674897f);
    const __m256 c3 =  _mm256_set1_ps (0.499990642f);
    const __m256 c4 =  _mm256_set1_ps (0.999999762f);
    const __m256 c5 =  _mm256_set1_ps (1.000000000f);
    
        3
  •  5
  •   jenkas    5 年前

    我对这个做了很多研究,发现了这个,它的相对精度约为1-07e,并且很容易转换为向量指令。 只有4个常量、5个乘法和1个除法,速度是内置exp()函数的两倍。

    float fast_exp(float x)
    {
        const float c1 = 0.007972914726F;
        const float c2 = 0.1385283768F;
        const float c3 = 2.885390043F;
        const float c4 = 1.442695022F;      
        x *= c4; //convert to 2^(x)
        int intPart = (int)x;
        x -= intPart;
        float xx = x * x;
        float a = x + c1 * xx * x;
        float b = c3 + c2 * xx;
        float res = (b + a) / (b - a);
        reinterpret_cast<int &>(res) += intPart << 23; // res *= 2^(intPart)
        return res;
    }
    

    转换为AVX(已更新)

    __m256 _mm256_exp_ps(__m256 _x)
    {
        __m256 c1 = _mm256_set1_ps(0.007972914726F);
        __m256 c2 = _mm256_set1_ps(0.1385283768F);
        __m256 c3 = _mm256_set1_ps(2.885390043F);
        __m256 c4 = _mm256_set1_ps(1.442695022F);
        __m256 x = _mm256_mul_ps(_x, c4); //convert to 2^(x)
        __m256 intPartf = _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
        x = _mm256_sub_ps(x, intPartf);
        __m256 xx = _mm256_mul_ps(x, x);
        __m256 a = _mm256_add_ps(x, _mm256_mul_ps(c1, _mm256_mul_ps(xx, x))); //can be improved with FMA
        __m256 b = _mm256_add_ps(c3, _mm256_mul_ps(c2, xx));
        __m256 res = _mm256_div_ps(_mm256_add_ps(b, a), _mm256_sub_ps(b, a));
        __m256i intPart = _mm256_cvtps_epi32(intPartf); //res = 2^intPart. Can be improved with AVX2!
        __m128i ii0 = _mm_slli_epi32(_mm256_castsi256_si128(intPart), 23);
        __m128i ii1 = _mm_slli_epi32(_mm256_extractf128_si256(intPart, 1), 23);     
        __m128i res_0 = _mm_add_epi32(ii0, _mm256_castsi256_si128(_mm256_castps_si256(res)));
        __m128i res_1 = _mm_add_epi32(ii1, _mm256_extractf128_si256(_mm256_castps_si256(res), 1));
        return _mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(res_0)), _mm_castsi128_ps(res_1), 1);
    }
    
        4
  •  0
  •   Peter Cordes    7 年前

    你可以 approximate the exponent yourself with Taylor series :

    exp(z) = 1 + z + pow(z,2)/2 + pow(z,3)/6 + pow(z,4)/24 + ...
    

    为此,您只需要AVX的加法和乘法运算。如果硬编码,然后乘以而不是除以,则1/2、1/6、1/24等系数的速度更快。

    根据您的精度要求,获取尽可能多的序列成员。请注意,您将得到相对误差:对于较小的 z 可能是吧 1e-6 绝对值,但很大 z 这将超过 1e-6 绝对的,静止的 abs(E-E1)/abs(E) - 1 小于 1e-6 (其中 E 是精确的指数 E1 是近似值)。

    更新:正如@PeterCordes在评论中提到的,可以通过分离整数和小数部分的指数,通过操纵二进制的指数字段来处理整数部分来提高精度 float 表示(基于2^x,而不是e^x)。然后,泰勒级数只需在小范围内最小化误差。

        5
  •  0
  •   huseyin tugrul buyukisik    2 年前

    对于归一化输入([-1,1]),可以使用多项式近似:

        // compute Simd exp() at a time (only optimized for Type=float)
        template<typename Type, int Simd>
        inline
        void expFast(float * const __restrict__ data, float * const __restrict__ result) noexcept
        {
    
            alignas(64)
            Type resultData[Simd];
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    Type(0.0001972591916103993980868836)*data[i] + Type(0.001433947376170863208244555);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    resultData[i]*data[i] + Type(0.008338950118885968265658448);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    resultData[i]*data[i] + Type(0.04164162895364054151059463);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    resultData[i]*data[i] + Type(0.1666645212581130408580066);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    resultData[i]*data[i] + Type(0.5000045184212300597437206);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                resultData[i] =    resultData[i]*data[i] + Type(0.9999999756072401879691824);
            }
    
            
            for(int i=0;i<Simd;i++)
            {
                result[i] =    resultData[i]*data[i] + Type(0.999999818912344906607359);
            }
    
        }
    

    对于在-1和1之间拾取的6400万个点,其平均误差为0.5 ULPS,最大误差为10 ULPS。与AVX1(推土机)上的std::exp相比,它有10倍的加速比。

    我认为可以将此函数与整数乘法结合起来,以支持所有幂。但是简单乘法部分需要是O(logN),而不是O(N),这样才能足够快地进行大的幂运算。例如,如果计算x^10,则只需对其本身进行1次额外运算即可得到x^20,而不是通过与x相乘进行10次额外运算。

    在循环中使用时,编译器会生成以下内容:

    .L2:
        vmovaps zmm1, ZMMWORD PTR [rax]
        add     rax, 64
        vmovaps zmm0, zmm1
        vfmadd132ps     zmm0, zmm8, zmm9
        vfmadd132ps     zmm0, zmm7, zmm1
        vfmadd132ps     zmm0, zmm6, zmm1
        vfmadd132ps     zmm0, zmm5, zmm1
        vfmadd132ps     zmm0, zmm4, zmm1
        vfmadd132ps     zmm0, zmm3, zmm1
        vfmadd132ps     zmm0, zmm2, zmm1
        vmovaps ZMMWORD PTR [rax-64], zmm0
        cmp     rax, rdx
        jne     .L2
    

    我认为它足够快,可以节省一些周期来处理输入的整数幂,可能达到浮点的极限(10^38)。