为什么使用 armv8 NEON 指令进行矩阵乘法 (float32_4x4) 速度较慢?

问题描述

以下代码使用 NEON 指令(来自 UE4)

// NEON 4x4 float matrix multiply (pattern taken from UE4).
// Each output column is built as a linear combination of the four
// columns of B, weighted by the four lanes of the matching column of A:
//   R[j] = B[0]*A[j].lane0 + B[1]*A[j].lane1 + B[2]*A[j].lane2 + B[3]*A[j].lane3
// NOTE(review): with column-major storage this is ret = B * A; UE4 uses a
// row-major convention where it reads as ret = A * B — confirm against callers.
// The pasted snippet was truncated (the vmlaq_lane_f32 calls were missing
// their vector and low/high operands); restored from the intact first group.
void matrixMultiplyNeon(float* ret,float32x4_t* A,float32x4_t* B) {

    float32x4_t * R = (float32x4_t*)ret;
    float32x4_t temp,r0,r1,r2,r3;

    // Column 0: lanes of A[0] broadcast against the columns of B.
    auto low  = vget_low_f32(A[0]);   // lanes 0,1
    auto high = vget_high_f32(A[0]);  // lanes 2,3
    temp = vmulq_lane_f32(      B[0],low, 0);
    temp = vmlaq_lane_f32(temp,B[1],low, 1);
    temp = vmlaq_lane_f32(temp,B[2],high,0);
    r0   = vmlaq_lane_f32(temp,B[3],high,1);

    // Column 1.
    low  = vget_low_f32(A[1]);
    high = vget_high_f32(A[1]);
    temp = vmulq_lane_f32(      B[0],low, 0);
    temp = vmlaq_lane_f32(temp,B[1],low, 1);
    temp = vmlaq_lane_f32(temp,B[2],high,0);
    r1   = vmlaq_lane_f32(temp,B[3],high,1);

    // Column 2.
    low  = vget_low_f32(A[2]);
    high = vget_high_f32(A[2]);
    temp = vmulq_lane_f32(      B[0],low, 0);
    temp = vmlaq_lane_f32(temp,B[1],low, 1);
    temp = vmlaq_lane_f32(temp,B[2],high,0);
    r2   = vmlaq_lane_f32(temp,B[3],high,1);

    // Column 3.
    low  = vget_low_f32(A[3]);
    high = vget_high_f32(A[3]);
    temp = vmulq_lane_f32(      B[0],low, 0);
    temp = vmlaq_lane_f32(temp,B[1],low, 1);
    temp = vmlaq_lane_f32(temp,B[2],high,0);
    r3   = vmlaq_lane_f32(temp,B[3],high,1);

    R[0] = r0;
    R[1] = r1;
    R[2] = r2;
    R[3] = r3;
}

下面的代码是我的普通矩阵乘法,使用数组 float[16]

// Scalar 4x4 float matrix multiply, column-major layout: ret = m1 * m2.
// ret may alias m1 or m2: the product is accumulated into a local buffer
// and copied out only at the end.
void matrixMultiply(float* ret,float* m1,float* m2) {
    float product[16];
    // product(row, col) = sum over k of m1(row, k) * m2(k, col),
    // where element (r, c) lives at index c * 4 + r.
    for (int col = 0; col < 4; ++col) {
        for (int row = 0; row < 4; ++row) {
            float acc = 0.0f;
            for (int k = 0; k < 4; ++k)
                acc += m1[k * 4 + row] * m2[col * 4 + k];
            product[col * 4 + row] = acc;
        }
    }
    memcpy(ret,product,sizeof(float) * 16);
}

测试是一个 1024*1024 次的 for 循环,结果是: 不使用 NEON 为 366 毫秒,使用 NEON 为 428 毫秒。

为什么 NEON 代码更慢以及如何优化?

解决方法

一旦您使用 vget_low 和/或 vget_high,编译器往往会生成混乱、低效的排列代码。

Neon 内在函数仅可用于处理输入和输出通常匹配 1:1 的连续数据。对于排列,您最好用汇编语言编写代码。

顺便说一句,你应该用 vld4 转置矩阵 A。

并考虑使用普通的 float 指针,而不是 float32x4_t *

// NEON 4x4 float matrix multiply, column-major layout: ret = A * B.
//
// Each result column j is a linear combination of the columns of A,
// weighted by the four elements of column j of B:
//   ret_col[j] = A_col0*B[4j+0] + A_col1*B[4j+1] + A_col2*B[4j+2] + A_col3*B[4j+3]
//
// All operations are purely "vertical" — no vget_low/vget_high and no
// cross-lane permutes, which is what makes NEON fast here.  The *_laneq
// multiply-accumulate intrinsics are armv8/AArch64 forms that map directly
// to FMLA (vector, by element) instructions.
//
// Bug fixed versus the previously posted version: it accumulated
// matA.val[0..3] * matB.val[j] elementwise, which computes
// (sum of the rows of A) ∘ B_col_j — not a matrix product.  The vld4q
// transpose of A was also counterproductive, since a transposed A would
// require horizontal dot products.
void matrixMultiplyNeon(float* ret,float* A,float* B) {

    const float32x4_t a0 = vld1q_f32(A);        // column 0 of A
    const float32x4_t a1 = vld1q_f32(A + 4);    // column 1 of A
    const float32x4_t a2 = vld1q_f32(A + 8);    // column 2 of A
    const float32x4_t a3 = vld1q_f32(A + 12);   // column 3 of A

    for (int j = 0; j < 4; ++j) {
        const float32x4_t b = vld1q_f32(B + 4 * j);   // column j of B
        float32x4_t r = vmulq_laneq_f32(a0, b, 0);    //   A_col0 * B(0,j)
        r = vfmaq_laneq_f32(r, a1, b, 1);             // + A_col1 * B(1,j)
        r = vfmaq_laneq_f32(r, a2, b, 2);             // + A_col2 * B(2,j)
        r = vfmaq_laneq_f32(r, a3, b, 3);             // + A_col3 * B(3,j)
        vst1q_f32(ret + 4 * j, r);
    }
}