区间不同数的个数-树状数组/线段树/莫队/主席树

目录

引言

本问题将使用如下数据结构

数据结构 时间复杂度 实现难度
树状数组(Fenwick-Tree) O ( m log ⁡ n ) O(m \log n) O(mlogn) ⭐ (简单)
线段树(Segment-Tree) O ( m log ⁡ n ) O(m \log n) O(mlogn) ⭐⭐ (中等)
莫队(Mo) O ( n n ) O(n \sqrt n) O(nn ) ⭐⭐ (中等)
主席树(可持久化线段树) O ( ( n + m ) log ⁡ n ) O((n + m)\log n) O((n+m)logn) ⭐⭐⭐ (困难)

题目-HH的项链

离线-树状数组解法

树状数组维护当前数组出现的位置信息, 具体的来说

  • 为了保证时间复杂度, 要求指针 r r r一直单调
  • 对于当前遍历到的数字 w w w, 如果出现过那么将上一次出现的位置 − 1 -1 −1
  • 然后将当前位置记录 l a s t i = j last_i = j lasti=j
  • 然后 j j j枚举下一个位置

因为 j j j指针只会走一次, 算法时间复杂度 O ( m log ⁡ n ) O(m \log n) O(mlogn)

核心代码

cpp 复制代码
    sort(q + 1, q + m + 1);

    int j = 1;
    for (int i = 1; i <= m; ++i) {
        int l = q[i].l, r = q[i].r, id = q[i].id;
        while (j <= r) {
            if (last[w[j]]) add(last[w[j]], -1);
            add(j, 1);
            last[w[j]] = j;
            j++;
        }
        ans[id] = get(r) - get(l - 1);
    }

代码实现

cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int n, m;
int w[N];
struct Ask {
    int l, r, id;

    bool operator< (const Ask &a) const {
        return r < a.r;
    }
} q[M];
int tr[S], last[S], ans[M];

inline int lowbit(int x) {
    return x & -x;
}

void add(int u, int x) {
    for (int i = u; i < S; i += lowbit(i)) tr[i] += x;
}

int get(int u) {
    int ans = 0;
    for (int i = u; i; i -= lowbit(i)) ans += tr[i];
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> w[i];

    cin >> m;
    for (int i = 1; i <= m; ++i) {
        int l, r;
        cin >> l >> r;
        q[i] = {l, r, i};
    }

    sort(q + 1, q + m + 1);

    int j = 1;
    for (int i = 1; i <= m; ++i) {
        int l = q[i].l, r = q[i].r, id = q[i].id;
        while (j <= r) {
            if (last[w[j]]) add(last[w[j]], -1);
            add(j, 1);
            last[w[j]] = j;
            j++;
        }
        ans[id] = get(r) - get(l - 1);
    }

    for (int i = 1; i <= m; ++i) cout << ans[i] << '\n';

    return 0;
}

离线-线段树解法

因为树状数组维护的是位置的前缀和 , 线段树也可以解决

代码实现

cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int n, m;
int w[N];
struct Ask {
    int l, r, id;

    bool operator< (const Ask &a) const {
        return r < a.r;
    }
} q[M];
int last[S], ans[M];

struct Node {
    int l, r;
    int sum;
} tr[N << 2];

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void build(int u, int l, int r) {
    tr[u] = {l, r, 0};
    if (l == r) return;

    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int pos, int x) {
    if (tr[u].l == tr[u].r) {
        tr[u].sum += x;
        return;
    }

    int mid = tr[u].l + tr[u].r >> 1;
    if (pos <= mid) modify(u << 1, pos, x);
    if (pos > mid) modify(u << 1 | 1, pos, x);
    pushup(u);
}

int query(int u, int ql, int qr) {
    if (tr[u].l >= ql && tr[u].r <= qr) return tr[u].sum;
    int ans = 0;
    int mid = tr[u].l + tr[u].r >> 1;
    if (ql <= mid) ans += query(u << 1, ql, qr);
    if (qr > mid) ans += query(u << 1 | 1, ql, qr);
    return ans;
}

inline int lowbit(int x) {
    return x & -x;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> w[i];

    cin >> m;
    for (int i = 1; i <= m; ++i) {
        int l, r;
        cin >> l >> r;
        q[i] = {l, r, i};
    }

    build(1, 1, n);
    sort(q + 1, q + m + 1);

    int j = 1;
    for (int i = 1; i <= m; ++i) {
        auto [l, r, id] = q[i];
        while (j <= r) {
            if (last[w[j]]) modify(1, last[w[j]], -1);
            modify(1, j, 1);
            last[w[j]] = j;
            j++;
        }
        ans[id] = query(1, l, r);
    }

    for (int i = 1; i <= m; ++i) cout << ans[i] << '\n';

    return 0;
}

离线-莫队算法

假设最坏情况下, 每次查询一个区间 n n n, 开一个数组用于记录每个数字出现的次数 , 算法时间复杂度 O ( q n ) O(qn) O(qn), 无法通过

莫队优化

开一个 c n t cnt cnt数组用来记录每个数字出现的次数

( 1 ) (1) (1) 对查询区间进行排序

假设当前查询区间是 [ i , j ] [i, j] [i,j]

蓝色部分是下一段查询的区间

( 2 ) (2) (2)对于每个区间

  1. 首先移动 j j j指针, 移动到下一次查询的右端点 上, 同时对于当前数字 x x x, 如果未出现过那么不同数字的数量 + 1 +1 +1 , 否则不同数字的出现次数不发生变化 , 同时累计 c n t cnt cnt数组的值

  2. 再将指针 i i i移动到下一次查询的左端点, 对于当前需要删除的数字 x x x, 如果出现次数大于 1 1 1, 那么 c n t ( x ) − 1 cnt(x) - 1 cnt(x)−1, 不对答案产生影响, 否则答案 − 1 -1 −1

因为每次移动指针最坏情况下是 O ( n ) O(n) O(n)次, 算法时间复杂度最坏 O ( q n ) O(qn) O(qn), 还是没办法通过

算法核心优化 :因为算法瓶颈在指针会移动 O ( n q ) O(nq) O(nq)次数, 尝试使得右指针是单调的 (不会向回移动), 左指针分块, 具体的来说

区间左端点按照分块的编号排序 , 双关键字排序

  • 如果分块编号 相同按照区间右端点从小到大排序
  • 如果分块编号不同, 块小的在前面

将所有查询分为 n \sqrt n n 块, 每一块长度是也是 n \sqrt n n , 块内部区间的右端点是递增的

对于右指针来说, 块内 走的次数不会超过 n n n, 一共 n \sqrt n n 块, 最多移动 n n n \sqrt n nn 次

左指针分为两种情况

  • 块内最多移动 n \sqrt n n 次, 最多 q q q个询问, 算法时间复杂度最差 O ( q n ) O(q \sqrt n) O(qn )
  • 块间最多移动 2 n 2 \sqrt n 2n 次, 最多跨越 n − 1 \sqrt n - 1 n −1个块, 最差 O ( 2 n ) O(2n) O(2n)

因此左指针的最差时间复杂度是 O ( q n ) O(q \sqrt n) O(qn )

左右时间取最大值, 因此优化后的算法时间复杂度 最坏情况下是 O ( q n ) O(q \sqrt n) O(qn )

假设块的大小是 a a a, 右指针的移动次数是 n 2 a \frac{n ^ 2}{a} an2, 左指针的最大移动次数是 m a ma ma, 也就是
a = n 2 m a = \sqrt {\frac{n ^ 2}{m}} a=mn2

左右指针移动的次数相当

如果 a = n a = \sqrt n a=n 比较慢, 尝试将 a a a变为 n 2 m \sqrt {\frac{n ^ 2}{m}} mn2

核心代码

cpp 复制代码
    // i表示左端点, j表示右端点
    for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
        // l, r分别代表当前区间的左右端点
        auto [l, r, id] = q[k];
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        ans[id] = res;
    }

示例代码

cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int n, m, len;
int w[N], cnt[S], ans[M];
struct Ask {
    int l, r, id;
} q[M];

int get(int i) {
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
    int ba = get(a.l), bb = get(b.l);
    if (ba == bb) return a.r < b.r;
    return ba < bb;
}

void add(int x, int &ans) {
    if (!cnt[x]) ans++;
    cnt[x]++;
}

void del(int x, int &ans) {
    cnt[x]--;
    if (!cnt[x]) ans--;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> w[i];
    cin >> m;

    // 计算块长
    len = sqrt(n * n / m + 1);

    for (int i = 0; i < m; ++i) {
        int l, r;
        cin >> l >> r;
        q[i] = {l, r, i};
    }

    sort(q, q + m, cmp);

    for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
        auto [l, r, id] = q[k];
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        ans[id] = res;
    }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
    return 0;
}

在线-主席树(持久化线段树)

假设对于当前数字上一次出现的位置是 l a s t i last_i lasti

那么当前区间 [ l , r ] [l, r] [l,r]内是第一次出现该数字等价于 l a s t i < l last_i < l lasti<l, 为了方便统计定义 g ( i ) g(i) g(i)表示 l a s t i + 1 last_i + 1 lasti+1

那么答案就是统计在区间 [ l , r ] [l, r] [l,r]内 g ( i ) ≤ l g(i) \le l g(i)≤l的数字的个数, 可以使用主席树 维护 g ( i ) g(i) g(i)的取值

查询的时候直接查询 ≤ l \le l ≤l的所有 g ( i ) g(i) g(i)就是答案

算法时间复杂度 O ( m log ⁡ n ) O(m \log n) O(mlogn)

代码实现

  • 构建多个版本根节点数组 r o o t [ N ] root[N] root[N]
  • 构建初始版本主席树
  • 对每个 g ( i ) g(i) g(i)执行插入操作 O ( n log ⁡ n ) O(n \log n) O(nlogn)
  • 计算区间和, 版本查询 l − 1 , r l - 1, r l−1,r两个版本 O ( m log ⁡ n ) O(m \log n) O(mlogn)
cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int n, m;
int w[N];
struct Node {
    int ls, rs;
    int s;
} tr[4 * N + 17 * N];
int last[S], g[N];
int root[M], idx;

int build(int l, int r) {
    int u = ++idx;
    tr[u] = {0, 0, 0};
    if (l == r) return u;

    int mid = l + r >> 1;
    tr[u].ls = build(l, mid);
    tr[u].rs = build(mid + 1, r);
    return u;
}

int insert(int p, int l, int r, int x) {
    int u = ++idx;
    tr[u] = tr[p];
    tr[u].s = tr[p].s + 1;
    if (l == r) return u;
    int mid = l + r >> 1;
    if (x <= mid) tr[u].ls = insert(tr[p].ls, l, mid, x);
    if (x > mid) tr[u].rs = insert(tr[p].rs, mid + 1, r, x);
    return u;
}

int query(int p, int q, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr) return tr[q].s - tr[p].s;
    int ans = 0;
    int mid = l + r >> 1;
    if (ql <= mid) ans += query(tr[p].ls, tr[q].ls, l, mid, ql, qr);
    if (qr > mid) ans += query(tr[p].rs, tr[q].rs, mid + 1, r, ql, qr);
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    root[0] = build(1, n);
    
    for (int i = 1; i <= n; ++i) {
        cin >> w[i];
        g[i] = last[w[i]] + 1;
        last[w[i]] = i;
        root[i] = insert(root[i - 1], 1, n, g[i]);
    }

    cin >> m;
    while (m--) {
        int l, r;
        cin >> l >> r;
        // 注意这里查询的是 <= l的g[i]的数量
        int ans = query(root[l - 1], root[r], 1, n, 1, l);
        cout << ans << '\n';
    }

    return 0;
}
相关推荐
漫随流水3 分钟前
leetcode算法(145.二叉树的后序遍历)
数据结构·算法·leetcode·二叉树
Tony_yitao9 分钟前
22.华为OD机试真题:数组拼接(Java实现,100分通关)
java·算法·华为od·algorithm
2501_9418752811 分钟前
在东京复杂分布式系统中构建统一可观测性平台的工程设计实践与演进经验总结
c++·算法·github
sonadorje13 分钟前
梯度下降法的迭代步骤
算法·机器学习
漫随流水16 分钟前
leetcode算法(94.二叉树的中序遍历)
数据结构·算法·leetcode·二叉树
Jacen.L22 分钟前
SIGABRT (6) 中止信号详解
c++
王老师青少年编程42 分钟前
信奥赛C++提高组csp-s之并查集(案例实践)2
数据结构·c++·并查集·csp·信奥赛·csp-s·提高组
范纹杉想快点毕业1 小时前
嵌入式通信核心架构:从状态机、环形队列到多协议融合
linux·运维·c语言·算法·设计模式
智源研究院官方账号1 小时前
众智FlagOS 1.6发布,以统一架构推动AI硬件、软件技术生态创新发展
数据库·人工智能·算法·架构·编辑器·硬件工程·开源软件