模数乘法函数:在一个模数下将两个整数相乘

问题描述

我在 miller-rabin 素性测试的代码中遇到了这个模乘函数。这应该是为了消除计算 ( a * b ) % m 时发生的整数溢出。

我需要一些帮助来了解这里发生的事情。为什么这样做?那个数字文字 0x8000000000000000ULL 的意义是什么?

unsigned long long mul_mod(unsigned long long a,unsigned long long b,unsigned long long m) {
    unsigned long long d = 0,mp2 = m >> 1;
    if (a >= m) a %= m;
    if (b >= m) b %= m;
    for (int i = 0; i < 64; i++)
    {
        d = (d > mp2) ? (d << 1) - m : d << 1;
        if (a & 0x8000000000000000ULL)
            d += b;
        if (d >= m) d -= m;
        a <<= 1;
    }
    return d;
}

解决方法

此代码目前出现在 the modular arithmetic Wikipedia page 上,仅适用于最多 63 位的参数——见底部。

概述

计算普通乘法a * b的一种方法是添加b的左移副本——a中的每个1位一个。这类似于我们大多数人在学校做长乘法的方式,但简化了:因为我们只需要将 b 的每个副本“乘以”1 或 0,我们需要做的就是添加移位的副本b 的(当 a 的相应位为 1)或什么都不做(当它为 0 时)。

这段代码做了类似的事情。但是,为了避免溢出(主要是;见下文),不是将 b 的每个副本移动然后将其添加到总数中,而是将 b 的未移动副本添加到总数中,并依赖于稍后对总数执行左移以将其移到正确的位置。您可以将这些转变“作用于”到目前为止添加到总数中的所有被加数。例如,第一次循环迭代检查 a 的最高位,即第 63 位是否为 1(这就是 a & 0x8000000000000000ULL 所做的),如果是,则将 b 的未移位副本添加到总数;到循环完成时,前一行代码将总 d 左移 1 位 63 次。

这样做的主要优点是我们总是将已知小于 b 的两个数字(即 dm)相加,因此处理模环绕很便宜:我们知道 b + d < 2 * m,所以为了确保我们到目前为止的总数保持小于 m,检查是否 b + d < m 就足够了,如果不是,减去 {{1} }.如果我们改用先移后加的方法,我们将需要每位 m 模运算,这与除法一样昂贵 - 通常比减法昂贵得多。

模算术的一个特性是,每当我们想对某个数 % 取模执行一系列算术运算时,在通常的算术中执行它们,并取余数取模 m end 总是产生与对每个中间结果取余数模 m 相同的结果(前提是没有发生溢出)。

代码

在循环体的第一行之前,我们有不变量 md < m

线

b < m

是将总 d = (d > mp2) ? (d << 1) - m : d << 1; 左移 1 位的谨慎方法,同时将其保持在 d 范围内并避免溢出。我们不是先移动它然后测试结果是否为 0 .. m 或更大,而是测试它当前是否严格高于 m -- 因为如果是,加倍后,肯定会严格高于 RoundDown(m/2),因此需要减去 2 * RoundDown(m/2) >= m - 1 才能回到范围内。请注意,即使 m 中的 (d << 1) 可能溢出并丢失 (d << 1) - m 的最高位,但这并没有什么害处,因为它不会影响减法结果的最低 64 位,即我们唯一感兴趣的那些。(另请注意,如果 d 正好,我们之后会得到 d == m/2,这有点超出范围——但是将测试从 d == m 更改为d > mp2 来解决这个问题会打破 d >= mp2 是奇数和 m 的情况,所以我们必须忍受这个。没关系,因为它会在下面修复。)

为什么不直接写 d == RoundDown(m/2) 呢?假设,在无限精度算术中,d <<= 1; if (d >= m) d -= m;,所以我们应该执行减法——但是 d << 1 >= m 的高位打开,而 d 的其余部分小于 {{1 }}:在这种情况下,初始移位将丢失高位,d << 1 将无法执行。

限制输入 63 位或更少

上述边缘情况只能在m的高位开启时才会发生,也只有在if的高位也开启时才会发生(因为我们保持不变{{1} })。因此,即使 d 的值非常高,代码也很难正常工作。不幸的是,事实证明它仍然可以溢出其他地方,从而导致某些设置最高位的输入的答案不正确。例如,当 md < mm 时,正确答案应该是 a = 3,但是 the code will return 0x7FFFFFFFFFFFFFFDULL(查看正确答案的简单方法是重新运行 b = 0x7FFFFFFFFFFFFFFFULLm = 0xFFFFFFFFFFFFFFFFULL 的值交换)。具体而言,只要行 0x7FFFFFFFFFFFFFFEULL 溢出并使截断的 a 小于 b,导致错误跳过减法,就会发生此行为。

如果这种行为被记录在案(就像在维基百科页面上一样),这只是一个限制,而不是一个错误。

解除限制

如果我们替换行

d += b

d

the code will work for all inputs,包括设置了最高位的那些。神秘的第一行只是一种无条件(因此通常更快)的书写方式

m

测试 if (a & 0x8000000000000000ULL) d += b; if (d >= m) d -= m; unsigned long long x = -(a >> 63) & b; if (d >= m - x) d -= m; d += x; before 上运行,它被修改了 -- 就像旧的 unsigned long long x = (a & 0x8000000000000000ULL) ? b : 0; 测试,但是 d >= m - x(当d 的最高位打开)或 0(否则)已从两侧减去。这将测试 d >= m 是否会在添加 b 后为 a 或更大。我们知道 RHS d 永远不会下溢,因为最大的 m 可以是 x 并且我们已经在函数的顶部建立了 m - x