算法学习笔记(3.1): ST算法

ST表

在RMQ(区间最值)问题中,著名的ST算法就是倍增的产物。ST算法可以在 \(O(n \log n)\) 的时间复杂度能预处理后,以 \(O(1)\) 的复杂度在线回答 区间 [l, r] 内的最值。

当然,ST表不支持动态修改,如果需要动态修改,线段树 是一种良好的解决方案,是 \(O(n)\) 的预处理时间复杂度,但是查询需要 \(O(\log n)\) 的时间复杂度

那么ST表中倍增的思想是如何体现的呢?

一个序列的子区间明显有 \(n^2\) 个,根据倍增的思想,我们在这么多个子区间中选择一些长度为 \(2\) 的整数次幂的区间作为代表值。

设 \(st[i][j]\) 表示子区间 \([i, i+2^j)\) 里最大的数

也可以表示为 \([i, i + 2^j -1 ]\),无论如何,其中有 \(2^j\) 个元素
下文中的 \(a\) 表示原序列

递推边界明显是 \(st[i][0] = a[i]\)。

于是,根据成倍增长的长度,有了递推公式

\[st[i][j] = max(st[i][j-1],\;st[i+2^{j-1}][j-1]) \]

当询问任意区间 \([l, r]\) 的最值时,我们先计算出一个最大的 \(k\) 满足:\(2^k \le r - l + 1\),即需要不大于区间长度。那么,由于二进制划分我们可以知道,这个最大的 k 一定满足 \(2^{k+1}\ge r-l+1\),即我们只需要将两个长度为 \(2^k\) 的区间合并即可。

又根据 max(a, a) = a 可以知道,重复计算区间是没有任何问题的。

所以,在寻找最值的时候就有了以下公式:

\[max(a[l, r]) = max(st[l][k], st[r-2^k + 1][k]) \]

那么这里给出一种参考代码

cpp 复制代码
// 啊,写这种预处理以2位底的对数的整数值的方式
// 我主要是为了将代码模块化,做到低耦合度
// 完全是可以分开来写的
class Log2Factory {
private:
    int lg2[N];
public:
    void init(int n) {
        for (int i = 2; i <= n; ++i) lg2[i] = lg2[i >> 1] + 1;
    }

    // 重载()运算符
    int operator() (const int &i) {
        return lg2[i];
    }
};

template<typename T>
class STable {
private:
    typedef T(*OP_FUNC)(T, T);

    Log2Factory Log2;
    T f[N][17]; // maybe most of the times k=17 is ok, make sure 2^k greater than N;
    OP_FUNC op;
public:
    void setOp(OP_FUNC fc) {
        op = fc;
    }

    void init(T *a, int n) {
        for (int i = 1; i <= n; ++i)
            f[i][0] = *(++a);

        int t = Log2(n);
        // f[i][k] is the interval of [i, i + 2^k - 1]
        // so f[i][k] can equal to the op sum of [i, i^k - 1]
        // let r = i^k - 1
        // => f[r - (1^k) + 1][k] can equal to the op sum of [i][k]
        for (int k = 1; k <= t; ++k) {
            for (int i = 1; i + (1<<k) - 1 <= n; ++i)
                f[i][k] = op(f[i][k-1], f[i + (1<<(k-1))][k-1]);
        }
    }

    const T query(int l, int r) {
        int k = Log2(r - l + 1);
        return op(f[l][k], f[r - (1<<k) + 1][k]);
    }
};

这......写法很神奇,注意修改!

扩展 - 运算

ST 算法不仅仅是可以求区间的最值的,只要是满足静态 的,满足区间加法的问题大多数情况都可以通过 ST 表实现。

那么区间加法是什么意思呢?

定义我们需要对数列的筛选函数为 op ,则需要 op 满足以下性质

  • op(a, a) = a ,即重复参与运算不改变最终影响

  • op(a, b) = op(b, a) ,即满足交换律

  • op(a, op(b, c)) = op(op(a, b), c) ,即满足结合律

举个例子,如果我们求区间是否有负数,可以将 op 设为如下逻辑:

c 复制代码
bool op(bool a, bool b) {
    return a | b;
}

相应的,初始化的方式也需要更改

c 复制代码
if (a[i] < 0) st[i][0] = true;
else st[i][0] = false;

再举一个例子,如果我们需要求区间是否全为偶数时,则初始化为

c 复制代码
if (a[i] % 2 == 0) st[i][0] = true;
else st[i][0] = false;

操作 op 定义为

c 复制代码
bool op(bool a, bool b) {
    return a & b;
}

由此可见,其实ST算法可以做到的不仅仅是区间最值那么普通的东西啊。

但是,由于 加法 不满足性质一,所以,ST表通过这种方法并不能求得区间的所有满足某种性质的元素的个数。但是,通过另外一种 query 方式,我们可以做到这样。

扩展 - 区间

那么这个部分我们将讨论如何利用ST表做到上文例子中求区间偶数的个数。

同样,由于我们可以通过二进制划分,所以可以将某一个区间长度转化为多个长度为2的整数幂次方的子区间,并且可以保证这些区间不相互重叠。

所以我们可以利用这个处理 op(a, a) != a 的情况了。

其实这是借鉴了一点线段树的思路

还不如直接用线段树......

那么可以写出以下代码

c 复制代码
int query(int l, int r) {
    if (l == r) return st[l][0];
    int k = log2(r - l + 1);
    return op(st[l][k], query(l + (1<<k), r))
}

这样就满足了区间不重叠

或许会有一个问题,为什么初始化的时候不需要修改?

其实不难发现,初始化的合并是不会有重复贡献的情况的,即是每一次合并的区间是不会重叠的