重新学算法(1)——分治

现在很后悔本科算法课学得太水了,在这一年里重新学一遍补一补课,希望能有新的收获。

分治策略

分治(Divide and Conquer)策略,适用于可以将问题分解为规模较小但形式相同的子问题的情况,与递归联系紧密。

分治法在每一层递归中主要是三个步骤:

  1. 分解(Divide).将问题划分为若干字问题,形式与原问题相同,但是规模更小
  2. 解决(Conquer).子问题规模很小,直接求解,结束递归。否则递归求解
  3. 合并(Combine).将子问题的解合并成为原问题的最终解

有时问题分解后只需要解决一个子问题,如折半查找。这种算法可以写成尾递归或者循环的形式,这时候合并步骤就不需要了,这类问题又被特别称为减治(Decrease and Conquer)。我个人觉得形式上是一样的,只不过是分治的一种特殊情况。

主定理

分治法的时间复杂度分析通常采用递归式(recurrence)表达,为了求解递归式,常用的方法有递归树法和主方法。递归树法就是展开成一棵递归树,然后求和计算。主方法就是采用主定理(master theorem)来求解特定形式的递归式。

内容

给定如下形式的递归式: $$ T(n) = aT(\frac{n}{b})+f(n) $$

分成以下三种情况讨论:

  1. 如果存在 $\epsilon > 0$ 使得 $ f(n) = O(n^{\log_b{a}-\epsilon}) $, 那么 $ T(n) = \Theta(n^{\log_b{a}})$
  2. 如果存在 $ k \geq 0$ 使得 $ f(n) = \Theta(n^{\log_b{a}}\log_k{n}) $, 那么 $ T(n) = \Theta(n^{\log_b{a}}\log_{k+1}{n}) $
  3. 如果存在 $\epsilon > 0$ 有 $ f(n) = \Omega(n^{\log_b{a}+\epsilon}) $, 同时存在 $ c<1$ 和充分大的$n$, 满足 $af(\frac{n}{b}) \leq cf(n)$ 那么 $ T(n) = \Theta(f(n))$

形式看起来有些不太直观。对于我这种初学者来说,$ f(n) $ 一般都是多项式函数 $ O(n^d)$,此时只要考虑指数 $ d $ 与 $ log_b{a} $ 的大小关系即可,可以用来快速算一些算法的复杂度。

可以注意到主定理的形式只覆盖了一部分特定的递归式,如果想要求解任何一个递归式的值,可采用 Akra-Bazzi 方法.

证明

对于后面说的 $ f(n) $ 是多项式的形式,证明较容易。只要每次将递归项展开,得到一个等比数列求和,最后分情况来讨论即可。前面那个标准的形式看起来比较数学,我就不再继续深入研究了(其实是我数学水平不行)。

应用

如归并排序的递归式 $ T(n) = 2T(\frac{n}{2}) + O(n) $, $d=\log_b{a}=1$, 求解出复杂度为 $ O(n\log{n}) $

如折半查找的递归式 $ T(n) = T(\frac{n}{2}) + O(1) $, $ d=\log_b{a}=0$, 求解出复杂度为 $ O(\log{n}) $

经典实例

里面出现的我写的代码都是左闭右开区间的,为了这个区间的开闭纠结了好久,搞得我焦头烂额,折半查找都不会写了。最后决定以后代码都写左闭右开,这样是跟 STL 标准库还有 python 的 range() 保持一致的。

归并排序

归并排序(merge sort) 可以说是一个最经典的分治法的案例了,就连本科的计算机导论课上都会拿来做例子用来介绍算法,貌似白板手写归并也经常在面试中出现。

主要思路就是每次都将要排序的数组分成两个子数组,递归下去直到分成不可再分。利用 merge 过程将两个有序的子数组合并,最后合并成一个完整的排序后数组。

int tmp[LEN];
void merge(int arr[], int l, int mid, int r)
{
    int i = l, j = mid, k = l;
    while (i < mid && j < r) {
        tmp[k++] = arr[i] < arr[j]? arr[i++]:arr[j++];
    }
    while (i < mid) {
        tmp[k++] = arr[i++];
    }
    while (j < r) {
        tmp[k++] = arr[j++];
    }
    for (k = l; k < r; k++) {
        arr[k] = tmp[k];
    }
}
void merge_sort(int arr[], int l, int r)
{
    if (l >= r - 1) return ;
    int mid = l + (r - l) / 2;
    merge_sort(arr, l, mid);
    merge_sort(arr, mid, r);
    merge(arr, l, mid, r);
}

递归写最容易理解,其实也可以自底向上迭代来写,设置一个间隔,每次迭代将相邻的两个段合并。

快速排序

快排也是计算机本科生见得最多的算法之一了。主要思想就是每次选出一个中轴(pivot), 将小于 pivot 的元素放到左边,大于 pivot 的元素放到右边,然后对左右两个子数组进行递归。

void quick_sort(int arr[], int lo, int hi)
{
    if (lo >= hi - 1) return;
    int pos = partition(arr, lo, hi);
    quick_sort(arr, lo, pos);
    quick_sort(arr, pos + 1, hi);
}

快排中比较重要的是 partition 函数,比较常见的有两种写法: Lomuto’s 和 Hoare’s. 具体到实际的细节比如哪里 pivot 存在头部还是尾部、如何选择 pivot 等等也是有很多种实现方式,可以分析和优化的点还有很多。

Lomuto partition 的一种写法:

int partition(int arr[], int lo, int hi)
{
    int pivot = arr[hi-1];
    int i = lo;
    for (int j = lo; j < hi - 1; j++) {
        if (arr[j] < pivot) {
            std::swap(arr[i], arr[j]);
            i++;
        }
    }
    std::swap(arr[i], arr[hi - 1]);
    return i;
}

Hoare partition 的一种写法:

int partition(int arr[], int lo, int hi)
{
    int pivot = arr[lo];
    int i = lo - 1;
    int j = hi;
    while (true) {
        do {
            j--;
        } while (arr[j] > pivot);
        do {
            i++;
        } while (arr[i] < pivot);
        if (i < j) {
            std::swap(arr[i], arr[j]);
        } else {
            return j;
        }
    }
}
void qsort(int arr[], int l, int r)
{
    if (l >= r - 1)
        return;
    int pos = partition2(arr, l, r);
    qsort(arr, l, pos + 1);
    qsort(arr, pos + 1, r);
}

注意后者的划分函数返回的并不是 pivot 最终的下标,递归的时候那里跟前者是不一样的,我之前从未自己写对过这个算法哈哈

前者代码更加简洁,后者似乎在$O(n)$的常数上更有优势。

本科课本上的算法也是双向扫描,不过跟 Hoare 不一样,每次不是交换两个而是交替覆盖, 我一般习惯写这个,印象比较深刻。

int partition(int arr[], int lo, int hi)
{
    int pivot = arr[lo];
    int i = lo, j = hi - 1;
    while (i < j) {
        while (i < j && arr[j] >= pivot)
            j--;
        arr[i] = arr[j];
        while (i < j && arr[i] <= pivot)
            i++;
        arr[j] = arr[i];
    }
    arr[i] = pivot;
    return i;
}
// 递归同 lomuto

当然,如果只考虑写得爽的话在很多语言中可以一句话完事:

qsort [] = []
qsort (x:xs) = qsort [a|a<-xs,a<=x] ++ [x] ++ qsort [a|a<-xs,a>x]

快速选择

这个问题的目的是从一个数组中找出第 $k$ 大的元素. 最直接的办法就是排序,但是可以发现排序中很多操作是不必要的,因此复杂度上做不到最优。快速选择将快排的思路用到这里,可以达到 $ O(n) $.

具体思路利用到快排中的 partition, 分完之后可以根据 pivot 的位置判断 k 在 pivot 的前半段还是在后半段,递归下去,从而减而治之。

int select_kth(int arr[], int l, int r, int k)
{
    int pos = partition(arr[], l, r);
    if (pos == k - 1) {
        return arr[pos];
    } else if (pos > k - 1) { // k_th in left part
        return select_kth(arr, l, pos, k);
    } else { // kth in right part
        return select_kth(arr, pos + 1, r, k);
    }
}

数逆序数

这里逆序数的定义就是线代里那个。笨的方法就是 $O(n^2)$ 双重循环找,使用分治可以降低复杂度。

具体做法就是给这个数组做归并排序,在归并的过程中,数出跨两个有序数组的逆序数,然后递归加上排序前两边数组内部各自的逆序数,得到最终的结果。

// merge 时发现左边更大就计数(mid-i)
int count_inversions(int arr[], int l, int r)
{
    if (l >= r - 1) return 0;
    int mid = l + (r - l) / 2;
    int cl = count_inversions(arr, l, mid);
    int cr = count_inversions(arr, mid, r);
    int c = merge(arr, l, mid, r);
    return cl + cr + c;
}

最近点对

平面上有若干个点,现在要计算哪一对点离得最近。最 naive 的方法就是遍历每一对,遍历时会发现有很多点明显更远还是被计算了,用分治法可以更快。

将点按照 $x$ 坐标排序,每次将所有点分成两部分,递归计算左边和右边的最近距离。剩下的就只有横跨分界线的点对了,此时只需要考虑距离小于之前求出的最短距离的点对。

接下来要用到一些技巧。设之前的最小距离为 $\delta$, 此时只需要考虑分界线左右两边各 $\delta$ 长度的格子里面的点。垂直方向上也分成长度为 $\frac{\delta}{2}$ 的格子,则要找的符合条件的点不会超过两行。这样就将寻找的范围大大缩小了。

如果将符合条件的点按照 y 轴进行排序,那么只需要考虑 y 相邻的点之间的距离了。为了减少排序的开销,可以借鉴归并排序的思想,每次递归时将两边的点按照 y 轴坐标大小进行归并,这样每次递归的开销就只有 $ O(n) $ 了。

递归式 $ T(n) = 2T(\frac{n}{2}) + O(n) $, 最终开销只有 $ O(n\log{n}) $

#include <algorithm>
#include <cmath>
#include <cstdio>
struct Point {
    double x;
    double y;
};
Point ps[100005];
Point tmp[100005];
bool cmp_x(const Point& a, const Point& b) { return a.x < b.x; }
double dis(const Point& a, const Point& b)
{
    double dx = (a.x - b.x) * (a.x - b.x);
    double dy = (a.y - b.y) * (a.y - b.y);
    return sqrt(dx + dy);
}
void merge(int low, int high)
{
    int idx = low;
    int mid = (low + high) / 2;
    int i = low, j = mid;
    while (i < mid && j < high) {
        if (ps[i].y < ps[j].y) {
            tmp[idx++] = ps[i++];
        } else {
            tmp[idx++] = ps[j++];
        }
    }
    while (i < mid) {
        tmp[idx++] = ps[i++];
    }
    while (j < high) {
        tmp[idx++] = ps[j++];
    }
    for (int i = low; i < high; i++) {
        ps[i] = tmp[i];
    }
}
double closest_dis(int low, int high)
{
    if (high - 1 <= low) { // invalid
        return 1e20;
    }
    if (high - 2 == low) { // only two points
        merge(low, high);
        return dis(ps[low], ps[low + 1]);
    }
    int mid = (low + high) / 2;

    // recurse
    double left = closest_dis(low, mid);
    double right = closest_dis(mid, high);
    double min_dis = std::min(left, right);

    merge(low, high);
    // find points with [mid.x - delta, mid.x + delta]
    int cnt = 0;
    for (int i = low; i < high; i++) {
        if (ps[i].x >= ps[mid].x - min_dis && ps[i].x <= ps[mid].x + min_dis) {
            tmp[cnt++] = ps[i];
        }
    }
    for (int i = 0; i < cnt; i++) {
        for (int j = i + 1; j < cnt && tmp[j].y <= tmp[i].y + min_dis; j++) {
            double d = dis(tmp[i], tmp[j]);
            min_dis = std::min(min_dis, d);
        }
    }
    return min_dis;
}
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%lf %lf", &ps[i].x, &ps[i].y);
    }
    std::sort(ps, ps + n, cmp_x); // sort by x axis
    double ans = closest_dis(0, n);
    printf("%.2lf\n", ans);
    return 0;
}

快速傅里叶变换

待续