最直接的实现需要3条指令;
vmovl.s8
,
vmovl.u8
vmlal.s16
,进行有符号加长16位乘法运算,累加到32位寄存器中。作为
vmlal。s16
要对以下4个元素进行乘法和累加,请对4个元素执行4条指令。
对于aarch64语法,相应的指令如下
sxtl
,
uxtl
和
smlal
.
如果输出元素应水平聚合,则不能使用融合乘法累积指令
vmlal
vmovl。s8
,然后是
vmul.i16
vpaddl.s16
(水平聚合两个元素),然后是另一个元素
vpadd.i32
水平获取4个元素的总和。因此,8个输入元素有5条指令,或者一个完整的128位向量有10条指令,然后是一个最终的
vadd.s32
将最终结果累加到累加器。在AArch64上,相当于
vpadd。i32
addp
,可以处理128位向量,因此只需少一条指令。
如果您使用的是Instrinsic,那么实现可能如下所示:
int32x4_t vpdpbusd(int32x4_t sum, uint8x16_t input, int8x16_t weight) {
int16x8_t i1 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input)));
int16x8_t i2 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input)));
int16x8_t w1 = vmovl_s8(vget_low_s8(weight));
int16x8_t w2 = vmovl_s8(vget_high_s8(weight));
int16x8_t p1 = vmulq_s16(i1, w1);
int16x8_t p2 = vmulq_s16(i2, w2);
int32x4_t s1 = vpaddlq_s16(p1);
int32x4_t s2 = vpaddlq_s16(p2);
#if defined(__aarch64__)
int32x4_t s3 = vpaddq_s32(s1, s2);
#else
int32x4_t s3 = vcombine_s32(
vpadd_s32(vget_low_s32(s1), vget_high_s32(s1)),
vpadd_s32(vget_low_s32(s2), vget_high_s32(s2))
);
#endif
sum = vaddq_s32(sum, s3);
return sum;
}