0%

找到两个有序数组的中位数

这是 LeetCode 的第 4 题@linjie 问到我说,这题网上的解答看完都迷迷糊糊的,希望我能写篇博客重新讲一下。于是有此篇。

题目

现有不同时为空的非降序排列的数组 AB,其长度分别是 mn。不失一般性,假设 m <= n。记 AB 的中位数为 x。要求写代码实现,在 $\Theta(\log(m + n))$ 的时间复杂度内找到 x

分析

见到时间复杂度要求中有 $\log$,就肯定要想到二分搜索。

二分搜索的本质是在一个有序 randomly accessible array 中,寻找一个满足某种跟元素顺序相关的条件的元素的方法。这个「跟元素顺序相关的条件」,在最初的版本里是「等于某个数值」,扩展版本可以是「第一个不小于给定数值的值」之类的(参考前作)。展开来说,满足以下 4 个要素时,可以用到二分搜索:

  • array 只有 1 个,或者可以将多个 array 问题简化成 1 个 array 的问题;
  • array 是 randomly accessible 的;
  • array 是有序的;
  • 搜索 array 中的元素时,限制条件是跟元素顺序有关的。

因此,为了利用二分搜索,我们必须要想办法:

  1. 变两个序列的问题为 1 个序列的问题;
  2. 找到某个「跟元素顺序相关的条件」。

首先考虑 (1)。

iA 中的下标,而 jB 中的下标:A[i]A 中第一个大于等于中位数 x 的元素;同时 B[j]B 中第一个大于等于中位数 x 的元素。当 A 中元素全部小于 x 时,i = m;同理,当 B 中元素全部小于 x 时,j = n

这意味着,A 中小于 x 的元素数量为 iB 中小于 x 的元素数量为 jA 中大于等于 x 的元素数量为 m - iB 中大于等于 x 的元素数量为 n - j。满足条件:

  • m + n 是偶数:i + j = (m - i) + (n - j),此时 j = (m + n) / 2 - i
  • m + n 是奇数:i + j = (m - i) + (n - j) - 1,此时 j = (m + n - 1) / 2 - i = (m + n) / 2 - i(考虑整数除法除不尽时向零取整)。

所以 j 有统一的表达式 (m + n) / 2 - i。这样一来,我们就建立了两个序列下标之间的对应关系,从而将两个序列的问题变为了 1 个序列的问题。

接着考虑 (2)。

考虑 x 是中位数,而 A[i]B[j] 分别是 AB 两个序列中第一个大于等于中位数 x 的元素。因此有:

  1. A[i - 1] < x <= A[i]
  2. A[i - 1] < x <= B[j]
  3. B[j - 1] < x <= A[i]
  4. B[j - 1] < x <= B[j]

其中 (1) 和 (4) 是 trivial 的。考虑 (2) 和 (3),即得到目标「条件」:(i == 0 or A[i - 1] < x <= B[j]) and (i == m or B[j - 1] < x <= A[i])

因此有伪代码:

1
2
3
4
5
6
7
8
9
10
for i in range(0, m + 1):
bsearch to find i, s.t.:
(i == 0 or A[i - 1] < x <= B[j]) and (i == m or B[j - 1] < x <= A[i])

if (m + n) % 2 == 0:
# carefully handle index out-of-bound
return static_cast<double>(max(A[i - 1], B[j - 1]) + min(A[i], B[j])) / 2.0
else:
# carefully handle index out-of-bound
return min(A[i], B[j])

算法的时间复杂度是 $\Theta(\log(\min(m, n)))$,空间复杂度是 $\Theta(1)$

C++ 实现

给一个 C++ 版本的完整实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
static const auto io_sync_off = []() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
return nullptr;
} ();

class Solution {
public:
double findMedianSortedArrays(const std::vector<int>& A, const std::vector<int>& B) {
if (A.empty()) {
return findMedianInSortedArray(B);
} else if (B.empty()) {
return findMedianInSortedArray(A);
} else {
return bsearchWrapper(A, B);
}
}

private:
inline double findMedianInSortedArray(const std::vector<int>& v) {
size_t len = v.size();
if (len % 2 == 0) {
return static_cast<double>(v[len / 2] + v[len / 2 - 1]) / 2.0;
} else {
return v[len / 2];
}
}

inline double bsearchWrapper(const std::vector<int>& A, const std::vector<int>& B) {
const size_t m = A.size(), n = B.size();
if (m <= n) {
return bsearchHelper(A, B);
} else {
return bsearchHelper(B, A);
}
}

inline double bsearchHelper(const std::vector<int>& A, const std::vector<int>& B) {
const size_t m = A.size(), n = B.size();
size_t left = 0, right = m + 1;
size_t i, j;
while (left < right) {
i = left + (right - left) / 2;
j = (m + n) / 2 - i;
if (not(i == m or B[j - 1] < A[i])) {
left = i + 1;
} else if (not(i == 0 or A[i - 1] < B[j])) {
right = i;
} else {
break;
}
}
if ((m + n) % 2 == 0) {
double mx = (i == m) ? B[j] : ((j == n) ? A[i] : std::min(A[i], B[j]));
double mn = (i == 0) ? B[j - 1] : ((j == 0) ? A[i - 1] : std::max(A[i - 1], B[j - 1]));
return (mn + mx) / 2.0;
} else {
return (i == m) ? B[j] : ((j == n) ? A[i] : std::min(A[i], B[j]));
}
}
};
俗话说,投资效率是最好的投资。 如果您感觉我的文章质量不错,读后收获很大,预计能为您提高 10% 的工作效率,不妨小额捐助我一下,让我有动力继续写出更多好文章。