代码之家  ›  专栏  ›  技术社区  ›  Zheyuan Li

为什么我的展开循环结尾部分的gcc代码生成器看起来很复杂?

  •  1
  • Zheyuan Li  · 技术社区  · 6 年前

    谢谢你到目前为止的评论。很抱歉,我在最初的问题中使用了一个糟糕的例子,几乎每个人都会说: “哦,你应该用 memcopy !” 但这不是我的问题。

    我的问题是 通用的 关于如何手动展开循环。考虑一下这个例子,通过对数组中的所有元素求和:

    #include <stdlib.h>
    
    double sum (size_t n, double *x) {
      size_t nr = n & 1;
      double *end = x + (n - nr);
      double sum_x = 0.0;
      for (; x < end; x++) sum_x += *x;
      if (nr) sum_x += *x;
      return sum_x;
      }
    

    编译器生成的程序集允许类似的行为(与我最初问题中的数组复制示例所示的行为相同)

    sum:
      movq %rdi, %rcx
      andl $1, %ecx
      subq %rcx, %rdi
      leaq (%rsi,%rdi,8), %rdx
      cmpq %rdx, %rsi
      jnb .L5
      movq %rsi, %rax
      pxor %xmm0, %xmm0
    .L3:
      addsd (%rax), %xmm0
      addq $8, %rax
      cmpq %rax, %rdx
      ja .L3
      movq %rsi, %rax
      notq %rax
      addq %rax, %rdx
      shrq $3, %rdx
      leaq 8(%rsi,%rdx,8), %rsi
    .L2:
      testq %rcx, %rcx
      je .L1
      addsd (%rsi), %xmm0
    .L1:
      ret
    .L5:
      pxor %xmm0, %xmm0
      jmp .L2
    

    但是,如果我现在将“小数”部分安排在主循环之前(稍后在我发布的答案中挖掘出来),编译器会做得更好。

    #include <stdlib.h>
    
    double sum (size_t n, double *x) {
      size_t nr = n & 1;
      double *end = x + n;
      double sum_x = 0.0;
      if (nr) sum_x += *x;
      for (x += nr; x < end; x++) sum_x += *x;
      return sum_x;
      }
    
    sum:
      leaq (%rsi,%rdi,8), %rdx
      pxor %xmm0, %xmm0
      andl $1, %edi
      je .L2
      addsd (%rsi), %xmm0
    .L2:
      leaq (%rsi,%rdi,8), %rax
      cmpq %rax, %rdx
      jbe .L1
    .L4:
      addsd (%rax), %xmm0
      addq $8, %rax
      cmpq %rax, %rdx
      ja .L4
    .L1:
      ret
    

    我只使用了编译器标志 -O2 是的。所以正如peter所说,编译器生成的程序集应该接近c源代码。那么问题是,为什么编译器在后一种情况下做得更好?

    这其实不是一个与性能相关的问题。这只是我在检查编译器的程序集输出时无意中发现的(而且无法解释)我一直在编写的C项目中的C代码。再次感谢。感谢彼得为这个问题提出了一个更好的题目。


    原始问题:

    下面的小C函数复制 a ,向量 n 条目到 b 是的。应用深度2的手动循环展开。

    #include <stddef.h>
    
    void foo (ptrdiff_t n, double *a, double *b) {
      ptrdiff_t i = 0;
      ptrdiff_t nr = n & 1;
      n -= nr;                  // `n` is an even integer
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }                       // `i = n` when the loop ends
      if (nr) b[i] = a[i];
      }
    

    它使x64组件 gcc -O2 (任何 gcc 版本5.4+。但是,我发现输出中被注释的部分很奇怪。为什么编译器会生成它们?

    foo:
      movq %rdi, %rcx
      xorl %eax, %eax
      andl $1, %ecx
      subq %rcx, %rdi
      testq %rdi, %rdi
      jle .L11
    .L12:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8)
      addq $2, %rax
      cmpq %rax, %rdi           // `i` in %rax, `n` in %rdi
      jg .L12                   // the loop ends, with `i = n`, BELOW IS WEIRD
      subq $1, %rdi             // n = n - 1;
      shrq %rdi                 // n = n / 2;
      leaq 2(%rdi,%rdi), %rax   // i = 2 * n + 2;  (this is just `i = n`, isn't it?)
    .L11:
      testq %rcx, %rcx
      je .L10
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)
    .L10:
      ret
    

    类似的版本使用 size_t 而不是 ptrdiff_t 给出类似的结果:

    #include <stdlib.h>
    
    void bar (size_t n, double *a, double *b) {
      size_t i = 0;
      size_t nr = n & 1;
      n -= nr;                  // `n` is an even integer
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }                       // `i = n` when the loop ends
      if (nr) b[i] = a[i];
      }
    
    bar:
      movq %rdi, %rcx
      andl $1, %ecx
      subq %rcx, %rdi
      je .L20
      xorl %eax, %eax
    .L21:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8)
      addq $2, %rax
      cmpq %rax, %rdi           // `i` in %rax, `n` in %rdi
      ja .L21                   // the loop ends, with `i = n`, BUT BELOW IS WEIRD
      subq $1, %rdi             // n = n - 1;
      andq $-2, %rdi            // n = n & (-2);
      addq $2, %rdi             // n = n + 2;  (this is just `i = n`, isn't it?)
    .L20:
      testq %rcx, %rcx
      je .L19
      movsd (%rsi,%rdi,8), %xmm0
      movsd %xmm0, (%rdx,%rdi,8)
    .L19:
      ret
    

    这是另一个等价物,

    #include <stdlib.h>
    
    void baz (size_t n, double *a, double *b) {
      size_t nr = n & 1;
      n -= nr;
      double *b_end = b + n;
      while (b < b_end) {
        b[0] = a[0];
        b[1] = a[1];
        a += 2;
        b += 2;
        }                       // `b = b_end` when the loop ends
      if (nr) b[0] = a[0];
      }
    

    但是下面的程序集看起来更奇怪(尽管在 -O2 )中。现在 n个 , 都被复制了,当循环结束时,我们用5行代码结束 b_copy = 0 是吗?啊!

    baz:                        // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
      movq %rdi, %r8            // n_copy = n;
      andl $1, %r8d             // nr = n_copy & 1;
      subq %r8, %rdi            // n_copy -= nr;
      leaq (%rdx,%rdi,8), %rdi  // b_end = b + n;
      cmpq %rdi, %rdx           // if (b >= b_end) jump to .L31
      jnb .L31
      movq %rdx, %rax           // b_copy = b;
      movq %rsi, %rcx           // a_copy = a;
    .L32:
      movsd (%rcx), %xmm0
      addq $16, %rax
      addq $16, %rcx
      movsd %xmm0, -16(%rax)
      movsd -8(%rcx), %xmm0
      movsd %xmm0, -8(%rax)
      cmpq %rax, %rdi           // `b_copy` in %rax, `b_end` in %rdi
      ja .L32                   // the loop ends, with `b_copy = b_end`
      movq %rdx, %rax           // b_copy = b;
      notq %rax                 // b_copy = ~b_copy;
      addq %rax, %rdi           // b_end = b_end + b_copy;
      andq $-16, %rdi           // b_end = b_end & (-16);
      leaq 16(%rdi), %rax       // b_copy = b_end + 16;
      addq %rax, %rsi           // a += b_copy;   (isn't `b_copy` just 0?)
      addq %rax, %rdx           // b += b_copy;
    .L31:
      testq %r8, %r8            // if (nr == 0) jump to .L30
      je .L30
      movsd (%rsi), %xmm0       // xmm0 = a[0];
      movsd %xmm0, (%rdx)       // b[0] = xmm0;
    .L30:
      ret
    

    有人能解释一下编译器在这三种情况下的想法吗?

    2 回复  |  直到 6 年前
        1
  •  1
  •   Zheyuan Li    6 年前

    如果我按以下方式展开循环,编译器可以生成更整洁的代码。

    #include <stdlib.h>
    #include <stddef.h>
    
    void foo (ptrdiff_t n, double *a, double *b) {
      ptrdiff_t i = n & 1;
      if (i) b[0] = a[0];
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }
      }
    
    void bar (size_t n, double *a, double *b) {
      size_t i = n & 1;
      if (i) b[0] = a[0];
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }
      }
    
    void baz (size_t n, double *a, double *b) {
      size_t nr = n & 1;
      double *b_end = b + n;
      if (nr) b[0] = a[0];
      b += nr;
      while (b < b_end) {
        b[0] = a[0];
        b[1] = a[1];
        a += 2;
        b += 2;
        }
      }
    

    foo:
      movq %rdi, %rax
      andl $1, %eax
      je .L9
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)
      cmpq %rax, %rdi
      jle .L11
    .L4:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8)
      addq $2, %rax
    .L9:
      cmpq %rax, %rdi
      jg .L4
    .L11:
      ret
    

    bar:
      movq %rdi, %rax
      andl $1, %eax
      je .L20
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)
      cmpq %rax, %rdi
      jbe .L21
    .L15:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8)
      addq $2, %rax
    .L20:
      cmpq %rax, %rdi
      ja .L15
    .L21:
      ret
    

    baz:
      leaq (%rdx,%rdi,8), %rcx
      andl $1, %edi
      je .L23
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)
    .L23:
      leaq (%rdx,%rdi,8), %rax
      cmpq %rax, %rcx
      jbe .L22
    .L25:
      movsd (%rsi), %xmm0
      addq $16, %rax
      addq $16, %rsi
      movsd %xmm0, -16(%rax)
      movsd -8(%rsi), %xmm0
      movsd %xmm0, -8(%rax)
      cmpq %rax, %rcx
      ja .L25
    .L22:
      ret
    
        2
  •  0
  •   Gunther Schulz    6 年前

    如果你问为什么程序集比较大,是因为编译器不能假设你可能知道的。

    例如,如果您知道在复制期间不会修改源数组,请通过添加 const 指向源数据的限定符。

    void foo (ptrdiff_t n, double *a, double const *b)
    

    此外,如果您知道这两个内存范围永远不会重叠,请添加 restrict 两个指针的限定符。

    void foo (ptrdiff_t n, double *restrict a, double const *restrict b)
    

    最后,如果您想要最优化的副本(编译器供应商在这方面花费了大量时间),请使用 memcpy 对于非重叠范围,以及 memmove 用于重叠范围。