二进制搜索:如何确定数组的一半

问题描述

这两个公式有什么区别

mid = low + (high - low) / 2;


mid = (high + low) / 2;

解决方法

在第二版中,如果high + low大于int的最大值(假设highint),则它可能会溢出,从而引发未定义的行为。这个特定的错误已通过第一版解决。

第一个版本仍然存在问题,例如如果low是一个非常大的负数,则差值仍会溢出。

从c ++ 20开始,您应该为此使用std::midpoint,它可以处理大量的极端情况,并且对所有情况都适用。

这个看似简单的功能实际上难以实现,实际上,Marshall Clow在cppcon 2019上给出了一个小时的talk,其中涵盖了该功能的实现。

,

第一个是上等的(尽管还不够完善,请参见Binary Search: how to determine half of the array):

  1. 它适用于以下情况:未为highlow定义加法,而是为low添加间隔而定义。指针就是一个这样的例子,日期类型的对象可以是另一个。

  2. high + low会使类型溢出。对于带符号整数类型,其行为是不确定的。

,

两者都可能发生溢出。有符号整数溢出是未定义行为(UB)。

使用 unsigned 数学(通常在数组索引中使用),然后在low <= high时,low + (high - low) / 2;不会像潜在的(high + low) / 2那样溢出。

low <= high0 <= low时与 signed 数学相同。

为避免带符号数学(或带low > high无符号数学)任何溢出并且仍然仅使用int/unsigned数学,我认为以下内容会起作用。

mid = high/2 + low/2 + (high%2 + low%2)/2;

high/2 + low/2的符号不同于(high%2 + low%2)的符号时,可能会失败。

下面是一个更强大且经过测试的版本。也许我以后再简化。

#include <limits.h>
#include <stdio.h>

int midpoint(int a,int b) {
  int avg = a/2 + b/2;
  int small_sum = a%2 + b%2;
  avg += small_sum/2;
  small_sum %= 2;
  if (avg < 0) {
    if (small_sum > 0) avg++;
  } else if (avg > 0) {
    if (small_sum < 0) avg--;
  }
  return avg;
}

int midpoint_test(int a,int b) {
  intmax_t lavg = ((intmax_t)a + (intmax_t)b)/2;
  int avg = midpoint(a,b);
  printf("a:%12d b:%12d avg_wide_math:%12jd avg_midpoint:%12d\n",a,b,lavg,avg);
  return lavg == avg;
}

int main(void) {
  int a[] = {INT_MIN,INT_MIN+1,-100,-99,-2,-1,1,2,99,100,INT_MAX-1,INT_MAX};
  int n = sizeof a/ sizeof a[0];
  for (int i=0; i<n; i++) {
    for (int j=0; j<n; j++) {
      if (midpoint_test(a[i],a[j]) == 0) {
        puts("Oops");
        return 1;
      }
    }
  }
  puts("Success");
  return 0;
}
,

两个公式不同:

  • 两者都可能溢出,具体取决于lowhigh的值。
  • 即使没有溢出,它们也不一定会产生相同的结果:第一个计算中点,第二个计算两个数的平均值。

在接下来的讨论中,我们将假设lowmidhigh具有相同的类型。我们正在寻找一种安全的方法来找到lowhigh之间的中点或平均值,该值始终在类型范围内。

第一个公式mid = low + (high - low) / 2;如果已签名,将四舍五入到low,如果已签名并且highlow太远,则可能会溢出。 / p>

第二个公式mid = (high + low) / 2;会四舍五入到0,但是对于有符号和无符号类型的较大值high和/或low,可能会溢出。

在您的特定应用程序中,计算排序数组的中间元素的索引以执行二进制搜索,索引值lowhigh为非负数,low <= high。在这种约束下,两个公式都可以计算出相同的结果,但是第二个公式可能会溢出,而第一个公式不会。

因此,您应该使用 mid = low + (high - low) / 2; 作为mid = (high + low) / 2;的安全替代品。

通常情况下,计算平均值(第二个公式)而不会发生溢出是一个棘手的问题。下面是一组针对平均公式的解决方案,以及一个受chux回答启发的测试程序。它们可以适用于任何有符号整数类型:

#include <limits.h>
#include <stdio.h>
#include <stdint.h>

int average_chqrlie(int a,int b) {
    if (a <= b) {
        if (a >= 0)
            return a + ((b - a) >> 1);
        if (b < 0)
            return b - ((b - a) >> 1);
    } else {
        if (b >= 0)
            return b + ((a - b) >> 1);
        if (a < 0)
            return a - ((a - b) >> 1);
    }
    return (a + b) / 2;
}

int average_chqrlie2(int a,int b) {
    if (a > b) {
        int tmp = a;
        a = b;
        b = tmp;
    }
    if (a >= 0)
        return a + ((b - a) >> 1);
    if (b < 0)
        return b - ((b - a) >> 1);
    return (a + b) / 2;
}

int average_chqrlie3(int a,int b) {
    int half,mid;
    if (a < b) {
        half = (int)(((unsigned)b - (unsigned)a) / 2);
        mid = a + half;
        if (mid < 0)
            mid = b - half;
    } else {
        half = (int)(((unsigned)a - (unsigned)b) / 2);
        mid = b + half;
        if (mid < 0)
            mid = a - half;
    }
    return mid;
}

int average_chux(int a,int b) {
    int avg = a / 2 + b / 2;
    int small_sum = a % 2 + b % 2;
    avg += small_sum / 2;
    small_sum %= 2;
    if (avg < 0) {
        if (small_sum > 0)
            avg++;
    } else if (avg > 0) {
        if (small_sum < 0)
            avg--;
    }
    return avg;
}

int run_tests(const char *name,int (*fun)(int a,int b)) {
    int array[] = { INT_MIN,INT_MAX };
    int n = sizeof(array) / sizeof(array[0]);
    int status = 0;
    printf("Testing %s:",name);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int a = array[i],b = array[j];
            intmax_t lavg = ((intmax_t)a + (intmax_t)b) / 2;  // assuming sizeof(intmax_t) > size(int)
            int avg = fun(a,b);
            if (lavg != avg) {
                printf("\na:%12d  b:%12d  average_wide:%12jd  average:%12d",avg);
                status = 1;
            }
        }
    }
    puts(status ? "\nFailed" : " Success");
    return status;
}

int main() {
    run_tests("average_chqrlie",average_chqrlie);
    run_tests("average_chqrlie2",average_chqrlie2);
    run_tests("average_chqrlie3",average_chqrlie3);
    run_tests("average_chux",average_chux);
    return 0;
}
,

与第二个不同,第一个对于较大的低/高值不会导致溢出。始终首选使用mid = low + (high - low) / 2