如何提高大整数的乘法效率?

问题描述

这个周末我参照维基百科实现了基本的大整数乘法,使用的是 Toom-3 算法。但出乎意料的是,它从一开始就比长乘法(小学乘法)慢,而且随着位数增加始终没有反超。我希望程序能在 500 位以内就超过小学乘法,请问该怎么做?

我尝试过优化:为 vector 预留(reserve)了容量,并删除了多余的代码,但效果不明显。

另外,我应该用 vector<long long> 来存储各个数位(limb)吗?

Github 中的整个源代码:

// Storage type of a single limb. Two limbs are multiplied directly in this
// type, so it has to hold values up to (digit_base - 1)^2.
using BigIntBase = long long;
// Sequence of limbs that makes up the magnitude of a number.
using BigIntDigits = vector<BigIntBase>;

// Number of decimal digits stored in one limb.
static const int digit_base_len = 9;
// Radix b of the limb representation, i.e. 10^digit_base_len. The square of
// the largest limb is below 10^18, which fits in a 64-bit signed integer.
static const BigIntBase digit_base = 1'000'000'000;

class BigInt {

public:
  BigInt(int digits_capacity = 0,bool nega = false) {
    negative = nega;
    digits.reserve(digits_capacity);
  }

  BigInt(BigIntDigits _digits,bool nega = false) {
    negative = nega;
    digits = _digits;
  }

  BigInt(const span<const BigIntBase> &range,bool nega = false) {
    negative = nega;
    digits = BigIntDigits(range.begin(),range.end());
  }

  BigInt operator+(const BigInt &rhs) {
    if ((*this).negative == rhs.negative)
      return BigInt(plus((*this).digits,rhs.digits),(*this).negative);

    if (greater((*this).digits,rhs.digits))
      return BigInt(minus((*this).digits,(*this).negative);

    return BigInt(minus(rhs.digits,(*this).digits),rhs.negative);
  }

  BigInt operator-(const BigInt &rhs) { return *this + BigInt(rhs.digits,!rhs.negative); }

  BigInt operator*(const BigInt &rhs) {
    if ((*this).digits.empty() || rhs.digits.empty()) {
      return BigInt();
    } else if ((*this).digits.size() == 1 && rhs.digits.size() == 1) {
      BigIntBase val = (*this).digits[0] * rhs.digits[0];
      return BigInt(val < digit_base ? BigIntDigits{val} : BigIntDigits{val % digit_base,val / digit_base},(*this).negative ^ rhs.negative);
    } else if ((*this).digits.size() == 1)
      return BigInt(multiply(rhs,(*this).digits[0]).digits,(*this).negative ^ rhs.negative);
    else if (rhs.digits.size() == 1)
      return BigInt(multiply((*this),rhs.digits[0]).digits,(*this).negative ^ rhs.negative);

    return BigInt(toom3(span((*this).digits),span(rhs.digits)),(*this).negative ^ rhs.negative);
  }

  string to_string() {
    if (this->digits.empty())
      return "0";

    stringstream ss;
    if (this->negative)
      ss << "-";

    ss << std::to_string(this->digits.back());
    for (auto it = this->digits.rbegin() + 1; it != this->digits.rend(); ++it)
      ss << setw(digit_base_len) << setfill('0') << std::to_string(*it);

    return ss.str();
  }

  BigInt from_string(string s) {
    digits.clear();
    negative = s[0] == '-';
    for (int pos = max(0,(int)s.size() - digit_base_len); pos >= 0; pos -= digit_base_len)
      digits.push_back(stoll(s.substr(pos,digit_base_len)));

    if (s.size() % digit_base_len)
      digits.push_back(stoll(s.substr(0,s.size() % digit_base_len)));

    return *this;
  }

private:
  bool negative;
  BigIntDigits digits;

  const span<const BigIntBase> toom3_slice_num(const span<const BigIntBase> &num,const int &n,const int &i) {
    int begin = n * i;
    if (begin < num.size()) {
      const span<const BigIntBase> result = num.subspan(begin,min((int)num.size() - begin,i));
      return result;
    }

    return span<const BigIntBase>();
  }

  BigIntDigits toom3(const span<const BigIntBase> &num1,const span<const BigIntBase> &num2) {
    int i = ceil(max(num1.size() / 3.0,num2.size() / 3.0));
    const span<const BigIntBase> m0 = toom3_slice_num(num1,i);
    const span<const BigIntBase> m1 = toom3_slice_num(num1,1,i);
    const span<const BigIntBase> m2 = toom3_slice_num(num1,2,i);
    const span<const BigIntBase> n0 = toom3_slice_num(num2,i);
    const span<const BigIntBase> n1 = toom3_slice_num(num2,i);
    const span<const BigIntBase> n2 = toom3_slice_num(num2,i);

    BigInt pt0 = plus(m0,m2);
    BigInt pp0 = m0;
    BigInt pp1 = plus(pt0.digits,m1);
    BigInt pn1 = pt0 - m1;
    BigInt pn2 = multiply(pn1 + m2,2) - m0;
    BigInt pin = m2;

    BigInt qt0 = plus(n0,n2);
    BigInt qp0 = n0;
    BigInt qp1 = plus(qt0.digits,n1);
    BigInt qn1 = qt0 - n1;
    BigInt qn2 = multiply(qn1 + n2,2) - n0;
    BigInt qin = n2;

    BigInt rp0 = pp0 * qp0;
    BigInt rp1 = pp1 * qp1;
    BigInt rn1 = pn1 * qn1;
    BigInt rn2 = pn2 * qn2;
    BigInt rin = pin * qin;

    BigInt r0 = rp0;
    BigInt r4 = rin;
    BigInt r3 = divide(rn2 - rp1,3);
    BigInt r1 = divide(rp1 - rn1,2);
    BigInt r2 = rn1 - rp0;
    r3 = divide(r2 - r3,2) + multiply(rin,2);
    r2 = r2 + r1 - r4;
    r1 = r1 - r3;

    BigIntDigits result = r0.digits;
    if (!r1.digits.empty()) {
      shift_left(r1.digits,i);
      result = plus(result,r1.digits);
    }

    if (!r2.digits.empty()) {
      shift_left(r2.digits,i << 1);
      result = plus(result,r2.digits);
    }

    if (!r3.digits.empty()) {
      shift_left(r3.digits,i * 3);
      result = plus(result,r3.digits);
    }

    if (!r4.digits.empty()) {
      shift_left(r4.digits,i << 2);
      result = plus(result,r4.digits);
    }

    return result;
  }

  BigIntDigits plus(const span<const BigIntBase> &lhs,const span<const BigIntBase> &rhs) {
    if (lhs.empty())
      return BigIntDigits(rhs.begin(),rhs.end());

    if (rhs.empty())
      return BigIntDigits(lhs.begin(),lhs.end());

    int max_length = max(lhs.size(),rhs.size());
    BigIntDigits result;
    result.reserve(max_length + 1);

    for (int w = 0; w < max_length; ++w)
      result.push_back((lhs.size() > w ? lhs[w] : 0) + (rhs.size() > w ? rhs[w] : 0));

    for (int w = 0; w < result.size() - 1; ++w) {
      result[w + 1] += result[w] / digit_base;
      result[w] %= digit_base;
    }

    if (result.back() >= digit_base) {
      result.push_back(result.back() / digit_base);
      result[result.size() - 2] %= digit_base;
    }

    return result;
  }

  BigIntDigits minus(const span<const BigIntBase> &lhs,lhs.end());

    BigIntDigits result;
    result.reserve(lhs.size() + 1);

    for (int w = 0; w < lhs.size(); ++w)
      result.push_back((lhs.size() > w ? lhs[w] : 0) - (rhs.size() > w ? rhs[w] : 0));

    for (int w = 0; w < result.size() - 1; ++w)
      if (result[w] < 0) {
        result[w + 1] -= 1;
        result[w] += digit_base;
      }

    while (!result.empty() && !result.back())
      result.pop_back();

    return result;
  }

  void shift_left(BigIntDigits &lhs,const int n) {
    if (!lhs.empty()) {
      BigIntDigits zeros(n,0);
      lhs.insert(lhs.begin(),zeros.begin(),zeros.end());
    }
  }

  BigInt divide(const BigInt &lhs,const int divisor) {
    BigIntDigits reminder(lhs.digits);
    BigInt result(lhs.digits.capacity(),lhs.negative);

    for (int w = reminder.size() - 1; w >= 0; --w) {
      result.digits.insert(result.digits.begin(),reminder[w] / divisor);
      reminder[w - 1] += (reminder[w] % divisor) * digit_base;
    }

    while (!result.digits.empty() && !result.digits.back())
      result.digits.pop_back();

    return result;
  }

  BigInt multiply(const BigInt &lhs,const int multiplier) {
    BigInt result(lhs.digits,lhs.negative);

    for (int w = 0; w < result.digits.size(); ++w)
      result.digits[w] *= multiplier;

    for (int w = 0; w < result.digits.size(); ++w)
      if (result.digits[w] >= digit_base) {
        if (w + 1 == result.digits.size())
          result.digits.push_back(result.digits[w] / digit_base);
        else
          result.digits[w + 1] += result.digits[w] / digit_base;
        result.digits[w] %= digit_base;
      }

    return result;
  }

  bool greater(const BigIntDigits &lhs,const BigIntDigits &rhs) {
    if (lhs.size() == rhs.size()) {
      int w = lhs.size() - 1;
      while (w >= 0 && lhs[w] == rhs[w])
        --w;

      return w >= 0 && lhs[w] > rhs[w];
    } else
      return lhs.size() > rhs.size();
  }
};
位数 小学乘法 Toom-3
10 4588 10003
50 24147 109084
100 52165 286535
150 92405 476275
200 172156 1076570
250 219599 1135946
300 320939 1530747
350 415655 1689745
400 498172 1937327
450 614467 2629886
500 863116 3184277

解决方法

问题在于你在 toom3_slice_num 这类函数里做了大量的内存分配。由于传入的数是常量,这里可以改用 std::span(或者用一对迭代器 std::pair 指向实际的那一段)来引用其中一部分,而不必拷贝。toom3 本身同样是“分配地狱”。

multiply 可能会多触发一次重新分配。应当事先算出所需的位数,或者直接按当前大小加 1 预留空间。

另外,vector 应该改用 pmr 版本(std::pmr::vector,并配合合适的内存资源),这样分配几乎不需要加锁。

如果编译时不开启 -O2 或 -O3 优化,以上这些改动都是白费。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...