问题描述
我有以下任务: 计算 1 和 N 之间有多少个数字恰好有 K 个零非前导位。 (例如 710=1112 将有 0 个,4 将有 2 个)
N 和 K 满足条件 0 ≤ K,N ≤ 1000000000
这个版本使用 POPCNT 并且在我的机器上足够快:
%include "io.inc"
section .bss
n resd 1
k resd 1
ans resd 1
section .text
global CMAIN
CMAIN:
GET_DEC 4,n
GET_DEC 4,k
mov ecx,1
mov edx,0
;ecx is counter from 1 to n
loop_:
mov eax,ecx
popcnt eax,eax;in eax Now amount of bits set
mov edx,32
sub edx,eax;in edx Now 32-bits set=bits not set
mov eax,ecx;count leading bits
bsr eax,eax;
xor eax,0x1f;
sub edx,eax
mov eax,edx
; all this lines something like (gcc):
; eax=32-__builtin_clz(x)-_mm_popcnt_u32(x);
cmp eax,[k];is there k non-leading bits in ecx?
jnz notk
;if so,then increment ans
mov edx,[ans]
add edx,1
mov [ans],edx
notk:
;increment counter,compare to n and loop
inc ecx
cmp ecx,dword[n]
jna loop_
;print ans
PRINT_DEC 4,ans
xor eax,eax
ret
就速度而言应该没问题(~0.8 秒),但没有被接受,因为(我猜)测试服务器上使用的 cpu 太旧,所以它表明发生了运行时错误。
我尝试对 64K * 4 字节查找表使用预计数技巧,但速度不够快:
%include "io.inc"
section .bss
n resd 1
k resd 1
ans resd 1
wordbits resd 65536; bits set in numbers from 0 to 65536
section .text
global CMAIN
CMAIN:
mov ebp,esp; for correct debugging
mov ecx,0
;mov eax,ecx
;fill in wordbits,ecx is wordbits array index
precount_:
mov eax,ecx
xor ebx,ebx
;c is ebx,v is eax
;for (c = 0; v; c++){
; v &= v - 1; // clear the least significant bit set
;}
lloop_:
mov edx,eax
dec edx
and eax,edx
inc ebx
test eax,eax
jnz lloop_
;computed bits set
mov dword[wordbits+4*ecx],ebx
inc ecx
cmp ecx,65536
jna precount_
;0'th element should be 0
mov dword[wordbits],0
GET_DEC 4,edi;n
GET_DEC 4,esi;k
mov ecx,1
xor edx,edx
xor ebp,ebp
loop_:
mov eax,ecx
;popcnt eax,eax
mov edx,ecx
and eax,0xFFFF
shr edx,16
mov eax,dword[wordbits+4*eax]
add eax,dword[wordbits+4*edx]
;prevIoUs lines are to implement absent instruction popcnt.
; they simply do eax=wordbits[x & 0xFFFF] + wordbits[x >> 16]
mov edx,eax
;and the same as before:
;non-leading zero bits=32-bits set-__builtin_clz(x)
mov eax,ecx
bsr eax,eax
xor eax,0x1f
sub edx,edx
;compare to k again to see if this number has exactly k
;non-leading zero bits
cmp edx,esi
jnz notk
;increment ebp (answer) if so
mov edx,ebp
add edx,1
mov ebp,edx
;and (or) go to then next iteration
notk:
inc ecx
cmp ecx,edi
jna loop_
;print answer what is in ebp
PRINT_DEC 4,ebp
xor eax,eax
ret
(>1 秒)
我应该加速第二个程序(如果是,那么如何?)或者以某种方式用其他(哪些?)指令替换 POPCNT(我猜 SSE2 和更早版本应该可用)?
解决方法
这是算法优化的尝试。
我。 [0;] 范围内所需整数的数量2 **楼(log2(N)))
所有这些整数都小于 N
,因此我们只需要检查它们中有多少在前导一位以下正好有 K
个零位。
对于位长为 n
的整数,有 n - 1
个可能的位置来放置我们的零(位低于前导一位)。因此,位长为 n
的所需整数的数量是从 k
位(无重复,无序)中挑选 n - 1
个零的方法数。我们可以使用 binomial coefficient 公式计算:
n! / (k! * (n - k)!)
如果我们使用 32 位整数,那么 n
的最大可能值为 31(k
也是如此)。 31 的因数仍然很大,即使是 64 位数字也不能容纳,因此我们必须执行重复除法(可以在编译时预先计算 constexpr
)。
为了得到整数的总数,我们计算了 n
从 1 到 floor(log2(N))
的二项式系数并将它们相加。
二。范围内所需整数的数量 [2 ** floor(log2(N)); N]
从前导一位之后的位开始。并应用以下算法:
-
如果当前位为零,那么我们不能对该位做任何事情(它必须为零,如果我们将其更改为一,则整数值变得大于
N
),所以我们只需减少零位预算K
并移至下一位。 -
如果当前位是 1,那么我们可以假装它是 0。现在剩余的低有效位的任何组合都将适合低于
N
的范围。获取二项式系数值以计算从剩余位数中挑选剩余零数并添加到总数中的方法。
一旦我们用完比特或 K
变为零,算法就会停止。此时,如果 K
等于剩余位数,这意味着我们可以将它们清零以获得所需的整数,因此我们将总计数加一(计数 N
本身对总数)。或者,如果 K
为零且所有剩余位均为 1,那么我们也可以将 N
计入总数。
代码:
#include <stdio.h>
#include <chrono>
template<typename T>
struct Coefficients {
static constexpr unsigned size_v = sizeof(T) * 8;
// Zero-initialize.
// Indexed by [number_of_zeros][number_of_bits]
T value[size_v][size_v] = {};
constexpr Coefficients() {
// How many different ways we can choose k items from n items
// without order and without repetition.
//
// n! / k! (n - k)!
value[0][0] = 1;
value[0][1] = 1;
value[1][1] = 1;
for(unsigned i = 2; i < size_v; ++i) {
value[0][i] = 1;
value[1][i] = i;
T r = i;
for(unsigned j = 2; j < i; ++j) {
r = (r * (i - j + 1)) / j;
value[j][i] = r;
}
value[i][i] = 1;
}
}
};
template<typename T>
__attribute__((noinline)) // To make it easier to benchmark
T count_combinations(T max_value,T zero_bits) {
if( max_value == 0 )
return 0;
constexpr int size = sizeof(T) * 8;
constexpr Coefficients<T> coefs;
// assert(zeros_bits < size)
int bits = size - __builtin_clz(max_value);
T total = 0;
// Count all-ones count.
#pragma clang loop vectorize(disable)
for(int i = 0; i < bits - 1; ++i) {
total += coefs.value[zero_bits][i];
}
// Count interval [2**bits,max_value]
bits -= 1;
T mask = T(1) << bits;
max_value &= ~mask; // Remove leading bit
mask = mask >> 1;
#pragma clang loop vectorize(disable)
while( zero_bits && zero_bits < bits ) {
if( max_value & mask ) {
// If current bit is one,then we can pretend that it is zero
// (which would only make the value smaller,which means that
// it would still be < max_value) and grab all combinations of
// zeros within the remaining bits.
total += coefs.value[zero_bits - 1][bits - 1];
// And then stop pretending it's zero and continue as normal.
} else {
// If current bit is zero,we can't do anything about it,just
// have to spend a zero from our budget.
zero_bits--;
}
max_value &= ~mask;
mask = mask >> 1;
bits--;
}
// At this point we don't have any more zero bits,or we don't
// have any more bits at all.
if( (zero_bits == bits) ||
(zero_bits == 0 && max_value == ((mask << 1) - 1)) ) {
total++;
}
return total;
}
int main() {
using namespace std::chrono;
unsigned count = 0;
time_point t0 = high_resolution_clock::now();
for(int i = 0; i < 1000; ++i) {
count |= count_combinations<unsigned>(1'000'000'000,8);
}
time_point t1 = high_resolution_clock::now();
auto duration = duration_cast<nanoseconds>(t1 - t0).count();
printf("result = %u,time = %lld ns\n",count,duration / 1000);
return 0;
}
结果(N=1'000'000'000,K=8,在 i7-9750H 上运行):
result = 12509316,time = 35 ns
如果在运行时计算二项式系数,则需要 ~3.2 µs。