问题描述
我使用 GMP 库(gmplib)获取大数字,然后计算其数根(反复对各位数字求和,直到只剩一位):123 -> 6,74 -> 11 -> 2。
这是我所做的:
(原问题中的代码在此处缺失;其逻辑与下文的 `getnumericvalue_simple` 基本相同。)
它运行良好,但是在 Xeon W-3235 上有更快的方法吗?
解决方法
您可以使用如下代码。该算法的总体思路是:
- 按字节处理数据,直到达到高速缓存行对齐
- 一次读取一个缓存行,检查字符串的结尾,然后将数字添加到8个累加器中
- 将8个累加器归约为1个,并加上开头部分逐字节处理得到的计数
- 按字节处理余数
请注意,以下代码尚未经过测试。
// getnumericvalue(ptr)
// Presumed C prototype: unsigned getnumericvalue(const char *ptr)
//
// Returns the digital root of a NUL-terminated string of ASCII digits:
// 0 when the digit sum is 0 (empty string or only '0' digits), otherwise
// 1 + (sum - 1) % 9.  Pointer argument in %rdi, result in %eax
// (System V AMD64 convention).  Requires AVX-512BW (vptestmb, kortestq,
// 512-bit vpsadbw).
//
// The vector loop uses aligned 64-byte loads.  Such a load never crosses a
// page boundary, so reading the whole cache line containing the terminator
// cannot fault even though it may read bytes past the end of the string.
        .section .text
        .type getnumericvalue, @function
        .globl getnumericvalue
getnumericvalue:
        xor     %eax, %eax              // scalar digit sum lives in RAX

        // process string bytewise until we reach cache-line alignment
        test    $64-1, %dil             // is ptr aligned to 64 bytes?
        jz      0f
1:      movzbl  (%rdi), %edx            // load a byte from the string
        inc     %rdi                    // advance pointer
        test    %edx, %edx              // is this the NUL byte?
        jz      .Lend                   // if yes, finish this function
        sub     $0x30, %edx             // turn ASCII '0'..'9' into 0..9
        add     %rdx, %rax              // and add to the sum
        test    $64-1, %dil             // is ptr aligned to 64 bytes?
        jnz     1b                      // if not, process more data

        // process data in cache line increments until a line
        // containing the NUL byte is found
0:      vpbroadcastd zero(%rip), %zmm1  // 64 bytes of '0'
        vpxor   %xmm3, %xmm3, %xmm3     // 8 qword accumulators (clears all of zmm3)
        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // k0 bit set for every non-NUL byte
        kortestq %k0, %k0               // CF = 1 iff all 64 bytes are non-NUL
        jnc     0f                      // NUL in first line: skip the loop

        .balign 16
1:      add     $64, %rdi               // advance pointer
        vpsadbw %zmm1, %zmm0, %zmm0     // sum |byte - '0'| of each 8-byte
                                        // group into 8 qwords
        vpaddq  %zmm0, %zmm3, %zmm3     // add to accumulators
        vmovdqa32 (%rdi), %zmm0         // load the next cache line
        vptestmb %zmm0, %zmm0, %k0      // k0 bit set for every non-NUL byte
        kortestq %k0, %k0               // CF = 1 iff no NUL byte in the line
        jc      1b                      // go on unless a NUL byte was found

        // reduce the 8 qword accumulators into rdx
0:      vextracti64x4 $1, %zmm3, %ymm2  // extract high 4 qwords
        vpaddq  %ymm2, %ymm3, %ymm3     // and add them to the low 4
        vextracti128 $1, %ymm3, %xmm2   // extract high 2 qwords
        vpaddq  %xmm2, %xmm3, %xmm3     // and add them to the low 2
        vpshufd $0x4e, %xmm3, %xmm2     // swap the two qwords into xmm2
        vpaddq  %xmm2, %xmm3, %xmm3     // low qword of xmm3 is the total
        vmovq   %xmm3, %rdx             // move vector sum to rdx
        add     %rdx, %rax              // and add to the scalar head sum

        // process the tail (the cache line holding the NUL) bytewise
1:      movzbl  (%rdi), %edx            // load a byte from the string
        inc     %rdi                    // advance pointer
        test    %edx, %edx              // is this the NUL byte?
        jz      .Lend                   // if yes, finish this function
        sub     $0x30, %edx             // turn ASCII '0'..'9' into 0..9
        add     %rdx, %rax              // and add to the sum
        jmp     1b                      // unconditionally process next byte

        // digital root: 0 if the sum is 0, else 1 + (sum - 1) % 9
.Lend:  test    %rax, %rax              // empty or all-zero digit sum?
        jz      2f                      // then the result (eax) is already 0
        dec     %rax                    // sum - 1
        xor     %edx, %edx              // zero-extend RAX into RDX:RAX
        mov     $9, %ecx                // divide by 9
        div     %rcx                    // RDX = (sum - 1) % 9
        lea     1(%rdx), %eax           // result = remainder + 1, in 1..9
2:      vzeroupper                      // avoid AVX/SSE transition penalties
        ret                             // and return
        .size getnumericvalue, .-getnumericvalue

        // constants
        .section .rodata
        .balign 4
zero:   .ascii  "0000"                  // 4 bytes, read as a dword by vpbroadcastd
这是一个可移植的解决方案:
- 它先用朴素方法逐个处理开头的几位数字,直到 `ptr` 正确对齐为止。
- 然后,它在循环中一次读取8位数字,并把这些数字按字节并行地累加到一个64位累加器中。在把这个64位累加器拆分并累加到 `number` 之前,最多可以执行28次这样的操作(28 × 9 = 252,不会溢出单个字节)。
- 终止测试检查包中所有字节的高半字节是否都等于 `3`(即都是 ASCII 数字 '0'..'9')。
- 其余数字逐个处理。
测试代码如下(1亿位数的计时结果见代码之后):
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
/*
 * Digital root of a NUL-terminated string of ASCII digits (the approach
 * from the question).  '9' digits are skipped because they do not change
 * the result modulo 9.  A flag remembers whether any were seen, so that
 * strings such as "9" or "99" still yield 9 rather than 0.
 *
 * Returns 0 for an empty string or a string of only '0' digits,
 * otherwise a value in 1..9.
 */
unsigned getnumericvalue_simple(const char *in_str) {
    unsigned long number = 0;
    int saw_nine = 0;

    /* Test before reading each digit (not do/while) so that an empty
     * string is never read past its terminator. */
    for (const char *ptr = in_str; *ptr != '\0'; ptr++) {
        if (*ptr == '9') {
            saw_nine = 1; /* contributes 0 mod 9, but is not a zero digit */
        } else {
            number += (unsigned long)(*ptr - '0');
        }
    }
    if (number == 0) {
        return saw_nine ? 9u : 0u;
    }
    return (unsigned)((number - 1) % 9 + 1);
}
/*
 * Reference implementation: add up every digit one byte at a time, then
 * reduce the total to its digital root (0 when the total is 0).
 */
unsigned getnumericvalue_naive(const char *ptr) {
    unsigned long total = 0;

    for (; *ptr != '\0'; ++ptr) {
        total += *ptr - '0';
    }
    if (total == 0) {
        return 0;
    }
    return 1 + (total - 1) % 9;
}
/*
 * Digital root of a NUL-terminated string of ASCII digits, summing eight
 * digits per step with bytewise arithmetic inside a 64-bit integer.
 *
 *  1. Bytes are summed one at a time until ptr is 8-byte aligned.
 *  2. Eight bytes are loaded at a time.  Subtracting 0x30 from every byte
 *     maps '0'..'9' to 0..9.  If any byte then has a non-zero high nibble
 *     (for example the NUL terminator), the fast loop stops.
 *  3. Packs are added bytewise into pack1.  Each byte lane gains at most 9
 *     per pack, so 28 packs (28 * 9 = 252) never carry into the next lane.
 *     pack1 is then folded byte by byte into number.
 *  4. Any remaining bytes are summed one at a time.
 *
 * Returns 0 when the digit sum is 0, otherwise 1 + (sum - 1) % 9.
 *
 * NOTE(review): the aligned 8-byte load that encounters the terminator may
 * read up to 7 bytes past the end of the string.  An aligned load cannot
 * cross a page, so this does not fault in practice.  It is still outside
 * the C object model, and AddressSanitizer will report it.
 */
unsigned getnumericvalue_parallel(const char *ptr) {
    unsigned long long number = 0;
    unsigned long long pack1, pack2;

    /* align source on ull boundary */
    while ((uintptr_t)ptr & (sizeof(unsigned long long) - 1)) {
        if (*ptr == '\0')
            return number ? 1 + (number - 1) % 9 : 0;
        number += *ptr++ - '0';
    }
    /* scan 8 bytes at a time */
    for (;;) {
        pack1 = 0;
#define REP8(x) x;x;x;x;x;x;x;x
#define REP28(x) x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x
        /* memcpy instead of a pointer cast avoids a strict-aliasing
         * violation; with ptr aligned it compiles to a single load.
         * The break leaves the enclosing for loop. */
        REP28(memcpy(&pack2, ptr, sizeof pack2);
              pack2 -= 0x3030303030303030ULL;
              if (pack2 & 0xf0f0f0f0f0f0f0f0ULL)
                  break;
              ptr += sizeof(unsigned long long);
              pack1 += pack2);
        /* fold the 8 byte lanes into the scalar sum */
        REP8(number += pack1 & 0xFF; pack1 >>= 8);
    }
    /* fold the partial packs accumulated before the break */
    REP8(number += pack1 & 0xFF; pack1 >>= 8);
#undef REP8
#undef REP28
    /* finish trailing bytes */
    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}
/*
 * Benchmark driver.  Builds a string of `digits` ASCII digits (default
 * 1000000, or argv[1] parsed by strtol with automatic base detection),
 * then times each implementation with clock() and prints result and
 * elapsed milliseconds.
 */
int main(int argc, char *argv[]) {
    clock_t start;
    unsigned naive_result, simple_result, parallel_result;
    double naive_time, simple_time, parallel_time;
    long requested = 1000000;

    if (argc >= 2) {
        char *end;
        requested = strtol(argv[1], &end, 0);
        /* Reject garbage, negative counts and counts that do not fit an int. */
        if (end == argv[1] || *end != '\0' || requested < 0 || requested > INT_MAX) {
            fprintf(stderr, "invalid digit count: %s\n", argv[1]);
            return EXIT_FAILURE;
        }
    }
    int digits = (int)requested;

    /* size_t arithmetic: digits + 1 would overflow int for INT_MAX */
    char *p = malloc((size_t)digits + 1);
    if (p == NULL) {
        fprintf(stderr, "out of memory allocating %d digits\n", digits);
        return EXIT_FAILURE;
    }
    for (int i = 0; i < digits; i++)
        p[i] = "0123456789123456"[i & 15];
    p[digits] = '\0';

    start = clock();
    simple_result = getnumericvalue_simple(p);
    simple_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;
    start = clock();
    naive_result = getnumericvalue_naive(p);
    naive_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;
    start = clock();
    parallel_result = getnumericvalue_parallel(p);
    parallel_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    printf("simple: %d digits -> %u,%7.3f msec\n", digits, simple_result, simple_time);
    printf("naive: %d digits -> %u,%7.3f msec\n", digits, naive_result, naive_time);
    printf("parallel: %d digits -> %u,%7.3f msec\n", digits, parallel_result, parallel_time);

    free(p);
    return 0;
}
1亿位数的计时结果:

    simple: 100000000 digits -> 3,100.380 msec
    naive: 100000000 digits -> 3,98.128 msec
    parallel: 100000000 digits -> 3,7.848 msec

请注意,问题中发布的版本里额外的测试(跳过 '9')是不正确的,因为 `getnumericvalue("9")` 应该返回 9,而不是 0。
并行版本比简单版本快 12倍。
通过编译器内部函数(intrinsics)甚至汇编语言使用 AVX 指令,可以获得更高的性能;但对于非常大的数组,内存带宽似乎才是限制因素。