现在很后悔本科算法课学得太水了,在这一年里重新学一遍补一补课,希望能有新的收获。
分治策略
分治(Divide and Conquer)策略,适用于可以将问题分解为规模较小但形式相同的子问题的情况,与递归联系紧密。
分治法在每一层递归中主要是三个步骤:
- 分解(Divide).将问题划分为若干字问题,形式与原问题相同,但是规模更小
- 解决(Conquer).子问题规模很小,直接求解,结束递归。否则递归求解
- 合并(Combine).将子问题的解合并成为原问题的最终解
有时问题分解后只需要解决一个子问题,如折半查找。这种算法可以写成尾递归或者循环的形式,这时候合并步骤就不需要了,这类问题又被特别称为减治(Decrease and Conquer)。我个人觉得形式上是一样的,只不过是分治的一种特殊情况。
主定理
分治法的时间复杂度分析通常采用递归式(recurrence)表达,为了求解递归式,常用的方法有递归树法和主方法。递归树法就是展开成一棵递归树,然后求和计算。主方法就是采用主定理(master theorem)来求解特定形式的递归式。
内容
给定如下形式的递归式: $$ T(n) = aT(\frac{n}{b})+f(n) $$
分成以下三种情况讨论:
- 如果存在 $\epsilon > 0$ 使得 $ f(n) = O(n^{\log_b{a}-\epsilon}) $, 那么 $ T(n) = \Theta(n^{\log_b{a}})$
- 如果存在 $ k \geq 0$ 使得 $ f(n) = \Theta(n^{\log_b{a}}\log_k{n}) $, 那么 $ T(n) = \Theta(n^{\log_b{a}}\log_{k+1}{n}) $
- 如果存在 $\epsilon > 0$ 有 $ f(n) = \Omega(n^{\log_b{a}+\epsilon}) $, 同时存在 $ c<1$ 和充分大的$n$, 满足 $af(\frac{n}{b}) \leq cf(n)$ 那么 $ T(n) = \Theta(f(n))$
形式看起来有些不太直观。对于我这种初学者来说,$ f(n) $ 一般都是多项式函数 $ O(n^d)$,此时只要考虑指数 $ d $ 与 $ log_b{a} $ 的大小关系即可,可以用来快速算一些算法的复杂度。
可以注意到主定理的形式只覆盖了一部分特定的递归式,如果想要求解任何一个递归式的值,可采用 Akra-Bazzi 方法.
证明
对于后面说的 $ f(n) $ 是多项式的形式,证明较容易。只要每次将递归项展开,得到一个等比数列求和,最后分情况来讨论即可。前面那个标准的形式看起来比较数学,我就不再继续深入研究了(其实是我数学水平不行)。
应用
如归并排序的递归式 $ T(n) = 2T(\frac{n}{2}) + O(n) $, $d=\log_b{a}=1$, 求解出复杂度为 $ O(n\log{n}) $
如折半查找的递归式 $ T(n) = T(\frac{n}{2}) + O(1) $, $ d=\log_b{a}=0$, 求解出复杂度为 $ O(\log{n}) $
经典实例
里面出现的我写的代码都是左闭右开区间的,为了这个区间的开闭纠结了好久,搞得我焦头烂额,折半查找都不会写了。最后决定以后代码都写左闭右开,这样是跟 STL 标准库还有 python 的 range() 保持一致的。
归并排序
归并排序(merge sort) 可以说是一个最经典的分治法的案例了,就连本科的计算机导论课上都会拿来做例子用来介绍算法,貌似白板手写归并也经常在面试中出现。
主要思路就是每次都将要排序的数组分成两个子数组,递归下去直到分成不可再分。利用 merge 过程将两个有序的子数组合并,最后合并成一个完整的排序后数组。
int tmp[LEN];
void merge(int arr[], int l, int mid, int r)
{
int i = l, j = mid, k = l;
while (i < mid && j < r) {
tmp[k++] = arr[i] < arr[j]? arr[i++]:arr[j++];
}
while (i < mid) {
tmp[k++] = arr[i++];
}
while (j < r) {
tmp[k++] = arr[j++];
}
for (k = l; k < r; k++) {
arr[k] = tmp[k];
}
}
void merge_sort(int arr[], int l, int r)
{
if (l >= r - 1) return ;
int mid = l + (r - l) / 2;
merge_sort(arr, l, mid);
merge_sort(arr, mid, r);
merge(arr, l, mid, r);
}
递归写最容易理解,其实也可以自底向上迭代来写,设置一个间隔,每次迭代将相邻的两个段合并。
快速排序
快排也是计算机本科生见得最多的算法之一了。主要思想就是每次选出一个中轴(pivot), 将小于 pivot 的元素放到左边,大于 pivot 的元素放到右边,然后对左右两个子数组进行递归。
void quick_sort(int arr[], int lo, int hi)
{
if (lo >= hi - 1) return;
int pos = partition(arr, lo, hi);
quick_sort(arr, lo, pos);
quick_sort(arr, pos + 1, hi);
}
快排中比较重要的是 partition 函数,比较常见的有两种写法: Lomuto’s 和 Hoare’s. 具体到实际的细节比如哪里 pivot 存在头部还是尾部、如何选择 pivot 等等也是有很多种实现方式,可以分析和优化的点还有很多。
Lomuto partition 的一种写法:
int partition(int arr[], int lo, int hi)
{
int pivot = arr[hi-1];
int i = lo;
for (int j = lo; j < hi - 1; j++) {
if (arr[j] < pivot) {
std::swap(arr[i], arr[j]);
i++;
}
}
std::swap(arr[i], arr[hi - 1]);
return i;
}
Hoare partition 的一种写法:
int partition(int arr[], int lo, int hi)
{
int pivot = arr[lo];
int i = lo - 1;
int j = hi;
while (true) {
do {
j--;
} while (arr[j] > pivot);
do {
i++;
} while (arr[i] < pivot);
if (i < j) {
std::swap(arr[i], arr[j]);
} else {
return j;
}
}
}
void qsort(int arr[], int l, int r)
{
if (l >= r - 1)
return;
int pos = partition2(arr, l, r);
qsort(arr, l, pos + 1);
qsort(arr, pos + 1, r);
}
注意后者的划分函数返回的并不是 pivot 最终的下标,递归的时候那里跟前者是不一样的,我之前从未自己写对过这个算法哈哈
前者代码更加简洁,后者似乎在$O(n)$的常数上更有优势。
本科课本上的算法也是双向扫描,不过跟 Hoare 不一样,每次不是交换两个而是交替覆盖, 我一般习惯写这个,印象比较深刻。
int partition(int arr[], int lo, int hi)
{
int pivot = arr[lo];
int i = lo, j = hi - 1;
while (i < j) {
while (i < j && arr[j] >= pivot)
j--;
arr[i] = arr[j];
while (i < j && arr[i] <= pivot)
i++;
arr[j] = arr[i];
}
arr[i] = pivot;
return i;
}
// 递归同 lomuto
当然,如果只考虑写得爽的话在很多语言中可以一句话完事:
qsort [] = []
qsort (x:xs) = qsort [a|a<-xs,a<=x] ++ [x] ++ qsort [a|a<-xs,a>x]
快速选择
这个问题的目的是从一个数组中找出第 $k$ 大的元素. 最直接的办法就是排序,但是可以发现排序中很多操作是不必要的,因此复杂度上做不到最优。快速选择将快排的思路用到这里,可以达到 $ O(n) $.
具体思路利用到快排中的 partition, 分完之后可以根据 pivot 的位置判断 k 在 pivot 的前半段还是在后半段,递归下去,从而减而治之。
int select_kth(int arr[], int l, int r, int k)
{
int pos = partition(arr[], l, r);
if (pos == k - 1) {
return arr[pos];
} else if (pos > k - 1) { // k_th in left part
return select_kth(arr, l, pos, k);
} else { // kth in right part
return select_kth(arr, pos + 1, r, k);
}
}
数逆序数
这里逆序数的定义就是线代里那个。笨的方法就是 $O(n^2)$ 双重循环找,使用分治可以降低复杂度。
具体做法就是给这个数组做归并排序,在归并的过程中,数出跨两个有序数组的逆序数,然后递归加上排序前两边数组内部各自的逆序数,得到最终的结果。
// merge 时发现左边更大就计数(mid-i)
int count_inversions(int arr[], int l, int r)
{
if (l >= r - 1) return 0;
int mid = l + (r - l) / 2;
int cl = count_inversions(arr, l, mid);
int cr = count_inversions(arr, mid, r);
int c = merge(arr, l, mid, r);
return cl + cr + c;
}
最近点对
平面上有若干个点,现在要计算哪一对点离得最近。最 naive 的方法就是遍历每一对,遍历时会发现有很多点明显更远还是被计算了,用分治法可以更快。
将点按照 $x$ 坐标排序,每次将所有点分成两部分,递归计算左边和右边的最近距离。剩下的就只有横跨分界线的点对了,此时只需要考虑距离小于之前求出的最短距离的点对。
接下来要用到一些技巧。设之前的最小距离为 $\delta$, 此时只需要考虑分界线左右两边各 $\delta$ 长度的格子里面的点。垂直方向上也分成长度为 $\frac{\delta}{2}$ 的格子,则要找的符合条件的点不会超过两行。这样就将寻找的范围大大缩小了。
如果将符合条件的点按照 y 轴进行排序,那么只需要考虑 y 相邻的点之间的距离了。为了减少排序的开销,可以借鉴归并排序的思想,每次递归时将两边的点按照 y 轴坐标大小进行归并,这样每次递归的开销就只有 $ O(n) $ 了。
递归式 $ T(n) = 2T(\frac{n}{2}) + O(n) $, 最终开销只有 $ O(n\log{n}) $
#include <algorithm>
#include <cmath>
#include <cstdio>
struct Point {
double x;
double y;
};
Point ps[100005];
Point tmp[100005];
bool cmp_x(const Point& a, const Point& b) { return a.x < b.x; }
double dis(const Point& a, const Point& b)
{
double dx = (a.x - b.x) * (a.x - b.x);
double dy = (a.y - b.y) * (a.y - b.y);
return sqrt(dx + dy);
}
void merge(int low, int high)
{
int idx = low;
int mid = (low + high) / 2;
int i = low, j = mid;
while (i < mid && j < high) {
if (ps[i].y < ps[j].y) {
tmp[idx++] = ps[i++];
} else {
tmp[idx++] = ps[j++];
}
}
while (i < mid) {
tmp[idx++] = ps[i++];
}
while (j < high) {
tmp[idx++] = ps[j++];
}
for (int i = low; i < high; i++) {
ps[i] = tmp[i];
}
}
double closest_dis(int low, int high)
{
if (high - 1 <= low) { // invalid
return 1e20;
}
if (high - 2 == low) { // only two points
merge(low, high);
return dis(ps[low], ps[low + 1]);
}
int mid = (low + high) / 2;
// recurse
double left = closest_dis(low, mid);
double right = closest_dis(mid, high);
double min_dis = std::min(left, right);
merge(low, high);
// find points with [mid.x - delta, mid.x + delta]
int cnt = 0;
for (int i = low; i < high; i++) {
if (ps[i].x >= ps[mid].x - min_dis && ps[i].x <= ps[mid].x + min_dis) {
tmp[cnt++] = ps[i];
}
}
for (int i = 0; i < cnt; i++) {
for (int j = i + 1; j < cnt && tmp[j].y <= tmp[i].y + min_dis; j++) {
double d = dis(tmp[i], tmp[j]);
min_dis = std::min(min_dis, d);
}
}
return min_dis;
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%lf %lf", &ps[i].x, &ps[i].y);
}
std::sort(ps, ps + n, cmp_x); // sort by x axis
double ans = closest_dis(0, n);
printf("%.2lf\n", ans);
return 0;
}
快速傅里叶变换
待续