浅聊算法竞赛中维护中位数的小技巧

首先来看暑假杭电多校的一道题目:


对于一个长度为 \(L\)(\(L\)为奇数) 的数组 \(a\),定义它的中位数 \(median(a)\) 为 \(a\) 中第 \(\frac{L+1}{2}\) 大的数。现在给你一个长度为 \(n\) 的排列,对于每对满足 \(1\leq i \leq j \leq n\) 且 \(j-i \equiv 0 (mod 2)\) 的 \((i,j)\),你需要计算 \(i*j*median(p[i,j])\)。输出所有值的和。

多测数 \(T \leq 20\),排列长度 \(n \leq 2000\)。


对于这道题,首先想到的是枚举所有 \([i,j]\),通过数据结构(对顶堆等)维护中位数。由于这些方法都带log,并且本题多测数据不保证 \(n\) 的总和。因此复杂度 \(O(n^2lognT)\),无法通过本题。

换个思路,枚举区间不行,就计算每个位置的贡献。想想一个位置能成为一个区间的中位数,需要满足什么条件?该区间中大于它和小于它的数的数量相等。

一个很经典的处理就是,将小于 \(x\) 的数赋-1,将大于 \(x\) 的数赋1,我们对这个数组做前缀和,记作 \(pre\)。于是问题就转换为,在 \(x\) 的左边找到 \(pre_i\),在 \(x\) 的右边找到 \(pre_j\),使得 \(pre_i = pre_j\)。那么 \(pre_j - pre_i = 0\),这就意味着,区间 \([l+1,r]\) 中大于 \(x\) 和小于 \(x\) 的数的数量相等。

对于每个数处理一次这样的前缀和数组是 \(O(n)\) 的,因此可以在 \(O(n^2T)\) 的时间下通过本题。

Code:

cpp 复制代码
#include <iostream>
#include <vector>
using namespace std;

using ll = long long;

void solve()
{
    int n;
    cin >> n;
    vector<int> a(n+10);
    for (int i = 1 ; i <= n ; i++) cin >> a[i];

    ll ans = 0;
    for (int i = 1 ; i <= n ; i++) //枚举中位数
    {
        vector<int> pre(n+10);
        vector<int> t(2*n+10); //桶,因为会出现负数,要带n的偏移
        t[n]++;
        for (int j = 1 ; j <= n ; j++)
        {
            pre[j] = pre[j-1];
            if (a[j] > a[i]) pre[j]++;
            else if (a[j] < a[i]) pre[j]--;

            if (j < i) t[pre[j]+n] += j + 1;
            else ans += (ll)a[i] * t[pre[j]+n] * j;
        }
    }

    cout << ans << endl;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int T;
    cin >> T;
    while (T--) solve();

    return 0;
}

基于本题,可以归纳一个维护中位数的小技巧:通过给大于 \(x\) 和 小于 \(x\) 的数分别赋值,并求前缀和,就可以把求中位数问题,转化成维护前缀和之差为0。

以下再做个拓展:序列 \(a\) 的中位数 \(median(a) \ge v\),当且仅当 \(a\) 中 \(cnt_{a_i \geq v} \geq cnt_{a_i < v}\)(式中是否取等号取决于题目怎么定义中位数)。也就是说,\(median(a) \geq v\) 的充要条件是 \(a\) 中 \(\geq v\) 的数的数量多于或等于 \(< v\) 的数的数量。这个式子比较好理解,这里不做证明。

这个式子给我们提供了一种二分的思路。

看下面这道题Submedians (Easy Version)


对于长度为 \(m\) 的数组 \(b\),整数 \(v\) 是 \(b\) 的中位数,当且仅当:

  • \(v\) 至少大于等于数组中 \(\lceil \frac{m}{2} \rceil\) 个元素,并且
  • \(v\) 至少小于等于数组中 \(\lceil \frac{m}{2} \rceil\) 个元素。

现在给定一个整数 \(k\) 和一个由 \(1\) 到 \(n\) 之间的整数构成的数组 \(a_1, \ldots, a_n\)。

如果存在至少一对下标 \((l, r)\) 满足:

  • \(1 \leq l \leq r \leq n\),
  • \(r - l + 1 \geq k\),
  • \(v\) 是子数组 \([a_l, \ldots, a_r]\) 的中位数,

则称 \(1\) 到 \(n\) 之间的整数 \(v\) 是一个子中位数。

可以证明,至少存在一个子中位数。请你找出最大的子中位数 \(v_{\max}\),以及任意一组对应的下标对 \((l, r)\)。


考虑二分:当 \(v\) 越大,\(cnt_{a_i\geq v}\) 越少,反之 \(cnt_{a_i < v}\) 越大,因此越大的数越不可能作为中位数。

同样将 \(\geq v\) 的数赋1,\(< v\) 的数赋-1,做前缀和维护即可。

Code:

cpp 复制代码
#include <iostream>
#include <vector>
using namespace std;

void solve()
{
    int n,k;
    cin >> n >> k;
    vector<int> a(n+10);
    for (int i = 1 ; i <= n ; i++) cin >> a[i];

    int ansl, ansr;
    auto check = [&](int mid)
    {
        vector<int> sum(n+10);
        for (int i = 1 ; i <= n ; i++)
        {
            sum[i] = sum[i-1];
            if (a[i] >= mid) sum[i]++;
            else sum[i]--;
        }

        int minn = 2e9;
        int l;
        for (int i = k ; i <= n ; i++)
        {
            if (sum[i-k] < minn)
            {
                minn = sum[i-k];
                l = i - k + 1;
            }
            if (sum[i]-minn >= 0)
            {
                ansl = l;
                ansr = i;
                return true;
            }
        }

        return false;
    };

    int l = 0;
    int r = n + 1;
    while (l+1 != r)
    {
        int mid = (l+r) >> 1;
        if (check(mid)) l = mid;
        else r = mid;
    }

    check(l);
    cout << l << " " << ansl << " " << ansr << endl;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int T;
    cin >> T;
    while (T--) solve();

    return 0;
}

一道与上题几乎一样的题目Max Median

题目很短,这里不做题意概述。需要注意的是,由于本题定义中位数为第 \(\lfloor \frac{L+1}{2} \rfloor\) 大的数,因此本题 \(median(a) \geq v\) 的条件为:\(cnt_{a_i \geq v} > cnt_{a_i < v}\)。

Code:

cpp 复制代码
#include <iostream>
#include <vector>
using namespace std;

void solve()
{
    int n,k;
    cin >> n >> k;
    vector<int> a(n+10);
    for (int i = 1 ; i <= n ; i++) cin >> a[i];

    int ansl, ansr;
    auto check = [&](int mid)
    {
        vector<int> sum(n+10);
        for (int i = 1 ; i <= n ; i++)
        {
            sum[i] = sum[i-1];
            if (a[i] >= mid) sum[i]++;
            else sum[i]--;
        }

        int minn = 2e9;
        int l;
        for (int i = k ; i <= n ; i++)
        {
            if (sum[i-k] < minn)
            {
                minn = sum[i-k];
                l = i - k + 1;
            }
            if (sum[i]-minn > 0)
            {
                ansl = l;
                ansr = i;
                return true;
            }
        }

        return false;
    };

    int l = 0;
    int r = n + 1;
    while (l+1 != r)
    {
        int mid = (l+r) >> 1;
        if (check(mid)) l = mid;
        else r = mid;
    }

    check(l);
    cout << l << endl;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int T = 1;
    while (T--) solve();

    return 0;
}