问题描述
我有这个 C:
#include <stddef.h>
size_t findChar(unsigned int length,char* __attribute__((aligned(16))) restrict string) {
for (size_t i = 0; i < length; i += 2) {
if (string[i] == '[' || string[i] == ' ') {
return i;
}
}
return -1;
}
它检查字符串的所有其他字符并返回字符串的第一个索引,即 [
或
。使用 x86-64 GCC 10.2 -O3 -march=skylake -mtune=skylake
,这是汇编输出:
findChar:
mov edi,edi
test rdi,rdi
je .L4
xor eax,eax
.L3:
movzx edx,BYTE PTR [rsi+rax]
cmp dl,91
je .L1
cmp dl,32
je .L1
add rax,2
cmp rax,rdi
jb .L3
.L4:
mov rax,-1
.L1:
ret
似乎可以显着优化,因为我看到了多个分支。如何编写 C 以便编译器使用 SIMD、字符串指令和/或矢量化对其进行优化?
Godbolt 上的交互式汇编输出:https://godbolt.org/z/W19Gz8x73
将其更改为具有明确声明长度的 VLA 并没有多大帮助:https://godbolt.org/z/bb5fzbdM1
这是修改后的代码版本,因此该函数将只返回每 100 个字符:https://godbolt.org/z/h8MjbP1cf
解决方法
我不知道如何说服编译器为此发出良好的自动矢量化代码。但我知道如何手动矢量化。由于您正在为 Skylake 编译,这里是您的函数的 AVX2 版本。未经测试。
#include <stddef.h>
#include <immintrin.h>
ptrdiff_t findCharAvx2( size_t length,const char* str )
{
const __m256i andMask = _mm256_set1_epi16( 0xFF );
const __m256i search1 = _mm256_set1_epi16( '[' );
const __m256i search2 = _mm256_set1_epi16( ' ' );
const char* const ptrStart = str;
const char* const ptrEnd = str + length;
const char* const ptrEndAligned = str + ( length / 32 ) * 32;
for( ; str < ptrEndAligned; str += 32 )
{
// Load 32 bytes,zero out half of them
__m256i vec = _mm256_loadu_si256( ( const __m256i * )str );
vec = _mm256_and_si256( andMask,vec );
// Compare 16-bit lanes for equality,combine with OR
const __m256i cmp1 = _mm256_cmpeq_epi16( vec,search1 );
const __m256i cmp2 = _mm256_cmpeq_epi16( vec,search2 );
const __m256i any = _mm256_or_si256( cmp1,cmp2 );
const int mask = _mm256_movemask_epi8( any );
// If neither character is found,mask will be 0.
// Otherwise,the least significant set bit = index of the first matching byte in `any` vector
#ifdef _MSC_VER
unsigned long bitIndex;
// That's how actual instruction works,it returns 2 things at once,flag and index
if( 0 == _BitScanForward( &bitIndex,(unsigned long)mask ) )
continue;
#else
if( 0 == mask )
continue;
const int bitIndex = __builtin_ctz( mask );
#endif
return ( str - ptrStart ) + bitIndex;
}
// Handle the remainder
for( ; str < ptrEnd; str += 2 )
{
const char c = *str;
if( c == '[' || c == ' ' )
return str - ptrStart;
}
return -1;
}