一种可操作性很强的二分查找思路

今天看到HuaHua的视频,总结二分查找。觉得视角很好,于是在其基础上,再用做过的题复习一下。

标准的二分查找的核心视角是,存在一个函数g(x),x的取值区间被分为两个部分,右侧g(x)为True,左侧g(x)为False。二分查找的目的。是在x取值区间内找到一个最小值m,使得g(m) = True

gx.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def right_bisect(hay, needle):  # upper-bound
lo, hi = 0, len(hay)
while lo < hi: # 左闭右开区间,只有严格小于时才表示搜索区间为空
mid = lo + (hi - lo) // 2
if hay[mid] > needle: # g(x)的意义就是找到大于needle的最小index
hi = mid # g(x)为True,那说明在成立的一侧,最小值在左边
else:
lo = mid + 1
return lo # 始终没有找到使g(x)=True,返回hi

def left_bisect(hay, needle): # lower-bound
lo, hi = 0, len(hay)
while lo < hi:
mid = lo + (hi - lo) // 2
if hay[mid] >= needle: # g(x)的意义就是找到大于等于needle的最小index
hi = mid
else:
lo = mid + 1
return lo

# 如果不存在是m使g(x)成立,返回hi

除了左闭右开区间,还有就是闭区间,本质没有区别,搜索的停止条件也和闭区间的意义相符合。而且因为对称性、可以直接使用hi索引,在一些变形里用起来还更方便一些

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def bisect_right(hay, needle):
lo, hi = 0, len(hay)-1
while lo <= hi: # 只有lo > hi才表示搜索空间为空
mid = lo + (hi - lo) // 2
if hay[mid] > needle: # g(m)=True的意义是m大于needle
hi = mid - 1 # 闭区间,mid-1才能把当前这个满足条件的mid排除在外
else:
lo = mid + 1
return lo

def bisect_left(hay, needle):
lo, hi = 0, len(hay)-1
while lo <= hi: # 只有lo > hi才表示搜索空间为空
mid = lo + (hi - lo) // 2
if hay[mid] >= needle: # g(m)=True的意义是m大于等于needle
hi = mid - 1
else:
lo = mid + 1
return lo

# 如果不存在是m使g(x)成立,返回hi+1

378. Kth smallest element in a sorted matrix

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Given a n x n matrix where each of the rows and columns are sorted in ascending order,
find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

matrix = [
[ 1, 5, 9],
[10, 11, 13],
[12, 13, 15]
]
k = 8
return 13.

当时用二分查找做这个,有一个疑惑就是,最后找到的那个数为什么一定在这个二维数组里呢?

当时给自己的的解释其实也有上述视角的影子,但是总结得不够清晰易懂

因为二分计算个数时包含等于,可以认为,如果一个数满足了个数要求,还是会继续减小这个数,直到逼近不满足的边缘,而这个边缘一定是等于的情况,也就是在数组中

在一个一维数组[0,2,6,7]中,如果要查找排名第2的数是哪一个,这个数a一定要满足,大于等于数组中的两个数(即0,2),所以a是2,3,4,5中的一个,比如说现在通过0、7,得到a为3,满足要求,用0、3去得到下一个a为1,发现1不满足,2,3得到2,满足了,hi=2,跳出循环

所以真相就是,满足大于等于n个数的集合的左边界,一定是等于达成的,然后搜索寻找这个左边界,找到的数一定在数组中

所谓的左边界,就是满足g(x)=True的最小值, g(x)的意义是, x是否大于等于二维数组中的k个数。

既然是这个要求,那么计算每一个行的二分查找就自然使用bisect_right,即upper_bound

1
2
3
4
5
6
7
8
9
10
11
12
def kth(self, matrix, k):
lo, hi = matrix[0][0], matrix[-1][-1]+1
while lo < hi:
mid = (lo + hi) // 2
small_cnt = 0
for row in matrix:
small_cnt += bisect_right(row, mid)
if small_cnt >= k:
hi = mid
else:
lo = mid + 1
return lo

1011. Capacity To Ship Packages Within D Days

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
A conveyor belt has packages that must be shipped from one port to another within D days.

The i-th package on the conveyor belt has a weight of weights[i]. Each day, we load the ship with packages on the conveyor belt (in the order given by weights). We may not load more weight than the maximum weight capacity of the ship.

Return the least weight capacity of the ship that will result in all the packages on the conveyor belt being shipped within D days.

Example 1:
Input: weights = [1,2,3,4,5,6,7,8,9,10], D = 5
Output: 15
Explanation:
A ship capacity of 15 is the minimum to ship all the packages in 5 days like this:
1st day: 1, 2, 3, 4, 5
2nd day: 6, 7
3rd day: 8
4th day: 9
5th day: 10

Note that the cargo must be shipped in the order given, so using a ship of capacity 14 and splitting the packages into parts like (2, 3, 4, 5), (1, 6, 7), (8), (9), (10) is not allowed.

Note:
1 <= D <= weights.length <= 50000
1 <= weights[i] <= 500

这个也符合g(x)的要求,显然,当容量足够大,一次性运送所有货物,只要1天。所以慢慢收紧收紧容量,直到找到临界值。

因为规定按顺序装载货物,g(x)的计算自然地使用贪心。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def min_cap(weights, D):
def days_needed(cap):
k = 1
this_ship_weight = 0
for x in weights:
this_ship_weight += x
if this_ship_weight > cap:
this_ship_weight = x
k += 1
return k
lo, hi = max(weights), sum(weights)
while lo <= hi:
mid = (lo + hi) // 2
if days_needed(mid) <= D:
hi = mid - 1
else:
lo = mid + 1
return lo

005. Longest Palindromic Substring

上面两个题的g(x)都是x大于某个值,g(x)一直为True。对应的一种变化,出现在Longest palindromic substring的Rabin-Karp解法里,即,当x小于某个值,g(x)一直为True,最后lo的意义是使得g(x)为False的最小值。

004. Median of Two Sorted Arrays

设两个数组A, B各有$n_1, n_2$个数,左中位数的索引设为k,$k=\frac{n_1+_2-1}2$,右中位数的索引就是$k+1$。索引还有另一个意义——在总的有序数组C中,共有k个数小于等于索引为k的数。所以思路是,尝试取$n_1$中前$m_1$个数,取$n_2$前$m_2$个数,$m_2=k-m_1$,希望这k个数确实是C最小的k个数,即前k个数。

ZTFbqS.png

怎么判断?大数集合和小数集合没有交叉,满足:

  • A[m1] >= A[m1-1]
  • B[m2] >= B[m2-1]
  • A[m1] >= B[m2-1]
  • B[m2] >= A[m1-1]

前两个是有序数组保证的,后两个需要手动检查。

在考虑如何检查之前,已经可以观察到以m1为自变量的二分查找影子,那不妨先考虑g(x)思路,g(x)表示取出的k个数都小于等于A[m1]。显然当m1越大,g(m1)都为True。所以我们可以找到使g(x)为True的最小m1。

找出的m1,自然满足等式3:A[m1]>=B[m2-1]。可以证明,因为是最小m1,等式4也满足。利用反证法,如果B[m2] < A[m1-1],这个不等式就可以看成是最小m1为m1-1时的等式3,即最小m1可以更小,推出矛盾。

找到C的前k个数后,中位数的备选项就在A[m1]B[m2]A[m1+1]B[m2+1]四个元素中,简单的做法就是进行排序,根据奇偶,选择第一个或者前两个元素来计算中位数。

然后就是考虑边界条件了。

  1. 两个数组都为空,返回None。

  2. 考虑二分查找时的越界:

    • 二分应该发生在元素更少的数组上。反例比如,10元素和0元素的数组,二分0元素数组就会造成0元素数组的越界
    • m2-1的越界。对每一轮二分的mid,g(x)会检查A[mid]和B[k-mid-1]的大小。当k=mid时,表明C的前k个元素都由A提供,因为二分发生在元素更少的数组上,所以此时A、B元素相等,A[n1-1]、B[0]是C的左右中位数。这种情况应该列入g(x)为True,保证lo是k个数的右边界
  3. 考虑二分查找后,k寻找中位数的越界:

    也就是考虑A[m1]B[m2]A[m1+1]B[m2+1]四个索引是否越界。可以确定的B[m2]一定存在,因为即使B对k个小数没有任何贡献,A也不可能独立提供左右中位数,不然它的容量将超过B。A为空时,或者找不到g(x)为True时,A[m1]不存在。然后各自检查右边界,来判断是否存在A[m1+1]B[m2+1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def median(nums1, nums2):
n1, n2 = len(nums1), len(nums2)
if n1 > n2:
return median(nums2, nums1)
if n2 == 0:
return None
k = (n1+n2-1) // 2
lo, hi = 0, n1-1
while lo <= hi:
m1 = lo + (hi-lo) // 2
m2 = k - m1
if m2 == 0 or nums1[m1] >= nums2[m2-1]:
hi = m1 - 1
else:
lo = m1 + 1

candidates = [nums2[k-lo]]
if n1 > 0 and lo < n1:
candidates.append(nums1[lo])
if lo < n1 - 1:
candidates.append(nums1[lo+1])
if k-lo < n2 - 1:
candidates.append(nums2[k-lo+1])
candidates.sort()
if (n1+n2) % 2 == 1:
return candidates[0]
else:
return sum(candidates[:2]) * 0.5

这个版本和官方题解有区别,也AC了。不过虽然讨论这么多,也还是感觉没有讨论得足够完备。有机会在补充,可能还要再换个角度。

总结

如果使用这种思路进行二分查找,前提是能找到这样一个具有二分性质的g(x)。返回值永远是lo,表明使得g(x)状态切换的边缘值。

具体的模板,我现在更倾向使用对称的那一版,即while lo <= hi

有一些问题,比如Rotated Array的题,因为没有合适的g(x)可用,只能回归到原始的二分思想,一步步删除不包含最终解的搜索空间。