【笔记】树状数组

【笔记】树状数组 目录


简介

树状数组是一种树形数据结构,支持在 O ( log ⁡ n ) O(\log n) O(logn) 的时间复杂度内进行 单点修改查询前缀和 的操作。

  • 优点:常数小,码量小,操作灵活简便。
  • 缺点:只能用来维护具有 结合律可差分 的信息。例如:区间和、积等,而不能维护区间最大(最小)值。

引入

现在想要让你实现两个操作:

  1. 单点修改
  2. 查询 1 , x 1,x 1,x 的和

在没有学过树状数组的时候你会怎么做?

1. 直接暴力

单点修改虽然方便,但前缀和是 O ( n ) O(n) O(n) 复杂度。

2. 维护前缀和数组

这样做虽然查询是 O ( 1 ) O(1) O(1) 了,但单点修改又是 O ( n ) O(n) O(n)。

总结

  • 暴力
    • 修改: O ( 1 ) O(1) O(1)
    • 查询: O ( n ) O(n) O(n)
  • 前缀和
    • 修改: O ( n ) O(n) O(n)
    • 查询: O ( 1 ) O(1) O(1)

那么我们不妨考虑一个折中的办法,两种操作都是 O ( log ⁡ n ) O(\log n) O(logn) 的复杂度。


定义

注:这里的数值表示的是该区间所有元素的和,也就是这个节点左下方的所有直接相关节点的总和。

例如:权值为 31 31 31 的节点表示的是权值分别为 19 , 10 , 1 19,10,1 19,10,1 的节点以及原数组中下表为 8 8 8 的元素之和。

显然,我们能求出原数组为

8 , 6 , 1 , 4 , 5 , 5 , 1 , 1 , 3 , 2 , 1 , 4 , 9 , 0 , 7 , 4 \] \[8,6,1,4,5,5,1,1,3,2,1,4,9,0,7,4\] \[8,6,1,4,5,5,1,1,3,2,1,4,9,0,7,4

这里插一句话:树状数组可以近似看成线段树去掉所有右儿子构成的树。


前置知识: lowbit ⁡ \operatorname{lowbit} lowbit 操作

一个二进制数的 lowbit ⁡ \operatorname{lowbit} lowbit 值就是这个数末尾第一个非零的位置的权值。

举个例子: 10001 0 ( 2 ) 100010_{(2)} 100010(2)

这个数的 lowbit ⁡ \operatorname{lowbit} lowbit 值是 1 0 ( 2 ) 10_{(2)} 10(2),即 2 ( 10 ) 2_{(10)} 2(10)。

那么这个怎么用代码实现呢?

cpp 复制代码
void lowbit(int x)
{
	return x & -x;
}

什么?你问为什么这么简单??

这都不知道,赶紧退役吧 h h \color{white}{这都不知道,赶紧退役吧hh} 这都不知道,赶紧退役吧hh

这里涉及到补码的概念。

一个二进制数的补码就是其二进制上的每一位都按位取反之后再 + 1 +1 +1。

还是那个数: 10001 0 ( 2 ) 100010_{(2)} 100010(2)

先按位取反: 01110 1 ( 2 ) 011101_{(2)} 011101(2)

再加一: 1111 0 ( 2 ) 11110_{(2)} 11110(2)

我们惊奇地发现,它们的后两位竟然是一样的!!!

我们把它们进行按位与运算 &,得到的结果是 1 0 ( 2 ) 10_{(2)} 10(2),即 2 ( 10 ) 2_{(10)} 2(10),与我们刚才进行手动 lowbit ⁡ \operatorname{lowbit} lowbit 运算的结果相同。

在计算机的运算过程中,由于是按照补码储存的,所以我们需要的 ~x + 1 就可以写成 -x

因此 lowbit ⁡ \operatorname{lowbit} lowbit 才能写成 x & -x


区间的表示方法

对于每个标号为 x x x 的节点,我们发现它父节点的标号为 x + lowbit x x+\text{lowbit}\ x x+lowbit x。

而每个区间的范围都是 ( x − lowbit ( x ) , x ] (x-\text{lowbit}(x),x] (x−lowbit(x),x]。


操作

单点修改

对于每个被修改的点,我们需要找到它的所有祖先节点并都进行修改操作。

考虑到它们标号的关系,我们只要每次加一个 lowbit(x) \text{lowbit(x)} lowbit(x) 就能找到所有祖先节点了。

代码:

cpp 复制代码
void add(int x, int c) // 将第 x 个数加 c
{
	for (int i = x; i <= n; i += lowbit(i))
		tr[i] += c;
}

前缀和查询

实践是检验真理的唯一标准。

经过我们的实践,找到该节点前面的所有节点,只需要每次减 lowbit(x) \text{lowbit(x)} lowbit(x) 即可。

代码:

cpp 复制代码
void query(int x) // 查询 1~x 的和
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
		res += tr[i];
	return res;
}

任意区间查询

我们都知道前缀和的性质。

∑ i = l r w i = ∑ i = 1 r w i − ∑ i = 1 l − 1 w i \sum_{i=l}^{r}w_i=\sum_{i=1}^{r}w_i-\sum_{i=1}^{l-1}w_i i=l∑rwi=i=1∑rwi−i=1∑l−1wi

代码:

cpp 复制代码
void Query(int l, int r) // 查询 [l,r] 的和
{
	return query(r) - query(l - 1);
}

例题1: 单点修改,区间查询

原题链接:P3374 【模板】树状数组 1

操作和上面的相同,直接上代码:

cpp 复制代码
#include <iostream>

using namespace std;

const int N = 500010;

int n, m;
int a[N];
int tr[N];

int lowbit(int x)
{
	return x & -x;
}

void add(int x, int c)
{
	for (int i = x; i <= n; i += lowbit(i))
		tr[i] += c;
}

int sum(int x)
{
	int res = 0;
	for (int i = x; i; i -= lowbit(i))
		res += tr[i];
	return res;
}

int main()
{
	int op, x, y;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i ++ )
		scanf("%d", &a[i]), add(i, a[i]);
	
	while (m -- )
	{
		scanf("%d%d%d", &op, &x, &y);
		if (op == 1) add(x, y);
		else printf("%d\n", sum(y) - sum(x - 1));
	}
	
	return 0;
}

例题2: 区间修改,单点查询

原题链接:P3368 【模板】树状数组 2

同一道题,思路已经在昨天的 【笔记】线段树 里面讲了,无非是维护一个差分数组。

代码:

cpp 复制代码
#include <iostream>

using namespace std;

const int N = 500010;

int n, m;
int a[N], b[N];
int tr[N];

int lb(int x)
{
	return x & -x;
}

void add(int x, int v)
{
	for (int i = x; i <= n; i += lb(i))
		tr[i] += v;
}

int q(int x)
{
	int res = 0;
	for (int i = x; i; i -= lb(i))
		res += tr[i];
	return res;
}

int main()
{
	cin >> n >> m;
	for (int i = 1; i <= n; i ++ )
		cin >> a[i], b[i] = a[i] - a[i - 1], add(i, b[i]);
	
	while (m -- )
	{
		int op, x, y, k;
		cin >> op >> x;
		if (op == 1)
		{
			cin >> y >> k;
			add(x, k), add(y + 1, -k);
		}
		else cout << q(x) << endl;
	}
}

例题3: 区间修改,区间查询

原题链接:P3372 【模板】线段树 1

不要说我用线段树的题练习树状数组,我找不到树状数组的模板题才用的这个

考虑用树状数组 tr[] 维护差分数组

则求原数组的前缀和

{ a 1 = d 1 a 2 = d 1 + d 2 a 3 = d 1 + d 2 + d 3 . . . . . . a n = d 1 + d 2 + . . . + d n \left\{\begin{matrix} a_1& =& d_1& & & & & & & \\ a_2& =& d_1& +& d_2& & & & & \\ a_3& =& d_1& +& d_2& +& d_3& & & \\ .& .& .& .& .& .& & & & \\ a_n& =& d_1& +& d_2& +& ...& +& d_n& \\ \end{matrix}\right. ⎩ ⎨ ⎧a1a2a3.an===.=d1d1d1.d1++.+d2d2.d2+.+d3...+dn

s i = ∑ i = 1 n a i = { d 1 d 1 + d 2 d 1 + d 2 + d 3 . . . . . . d 1 + d 2 + . . . + d n s_i=\sum_{i=1}^{n}a_i=\left\{\begin{matrix} d_1& & & & & & & \\ d_1& +& d_2& & & & & \\ d_1& +& d_2& +& d_3& & & \\ .& .& .& .& .& .& & & & \\ d_1& +& d_2& +& ...& +& d_n& \\ \end{matrix}\right. si=i=1∑nai=⎩ ⎨ ⎧d1d1d1.d1++.+d2d2.d2+.+d3.....+dn

我们考虑把后面的矩阵补全:

s i = ( n + 1 ) × ∑ i = 1 n d i − ∑ i = 1 n ( i × d i ) s_i=(n+1) \times \sum_{i=1}^{n}d_i-\sum_{i=1}^{n}(i \times d_i) si=(n+1)×i=1∑ndi−i=1∑n(i×di)

所以我们需要两个树状数组,tr1[] 维护差分数组,tr2[] 维护 i × d i i \times d_i i×di

代码:

cpp 复制代码
#include <iostream>

using namespace std;

typedef long long LL;

const LL N = 1000010;

LL n, m;
LL a[N];
LL t1[N], t2[N];

inline LL lowbit(LL x)
{
    return x & -x;
}

inline void add(LL t[], LL x, LL c)
{
    for (LL i = x; i <= n; i += lowbit(i))
        t[i] += c;
}

inline LL sum(LL t[], LL x)
{
    LL res = 0;
    for (LL i = x; i; i -= lowbit(i))
        res += t[i];
    return res;
}

inline LL psum(LL x)
{
    return sum(t1, x) * (x + 1) - sum(t2, x);
}

int main()
{
    scanf("%lld%lld", &n, &m);
    for (LL i = 1; i <= n; i ++ ) scanf("%lld", &a[i]);
    for (LL i = 1; i <= n; i ++ )
    {
        LL b = a[i] - a[i - 1];
        add(t1, i, b);
        add(t2, i, b * i);
    }

    while (m -- )
    {
        char op[2];
        LL l, r, d;
        scanf("%s%lld%lld", op, &l, &r);
        if (op[0] == '2')
        {
            printf("%lld\n", psum(r) - psum(l - 1));
        }
        else
        {
            scanf("%lld", &d);
            add(t1, l, d), add(t2, l, l * d);
            add(t1, r + 1, -d), add(t2, r + 1, -d * (r + 1));
        }
    }

    return 0;
}

最后,如果觉得对您有帮助的话,点个赞再走吧!

(后附极限卡常代码,70ms,较优解)

cpp 复制代码
#define qwq optimize
#pragma GCC qwq(1)
#pragma GCC qwq(2)
#pragma GCC qwq(3)
#pragma GCC qwq("Ofast")
#pragma GCC qwq("inline")
#pragma GCC qwq("-fgcse")
#pragma GCC qwq("-fgcse-lm")
#pragma GCC qwq("-fipa-sra")
#pragma GCC qwq("-ftree-pre")
#pragma GCC qwq("-ftree-vrp")
#pragma GCC qwq("-fpeephole2")
#pragma GCC qwq("-ffast-math")
#pragma GCC qwq("-fsched-spec")
#pragma GCC qwq("unroll-loops")
#pragma GCC qwq("-falign-jumps")
#pragma GCC qwq("-falign-loops")
#pragma GCC qwq("-falign-labels")
#pragma GCC qwq("-fdevirtualize")
#pragma GCC qwq("-fcaller-saves")
#pragma GCC qwq("-fcrossjumping")
#pragma GCC qwq("-fthread-jumps")
#pragma GCC qwq("-funroll-loops")
#pragma GCC qwq("-fwhole-program")
#pragma GCC qwq("-freorder-blocks")
#pragma GCC qwq("-fschedule-insns")
#pragma GCC qwq("inline-functions")
#pragma GCC qwq("-ftree-tail-merge")
#pragma GCC qwq("-fschedule-insns2")
#pragma GCC qwq("-fstrict-aliasing")
#pragma GCC qwq("-fstrict-overflow")
#pragma GCC qwq("-falign-functions")
#pragma GCC qwq("-fcse-skip-blocks")
#pragma GCC qwq("-fcse-follow-jumps")
#pragma GCC qwq("-fsched-interblock")
#pragma GCC qwq("-fpartial-inlining")
#pragma GCC qwq("no-stack-protector")
#pragma GCC qwq("-freorder-functions")
#pragma GCC qwq("-findirect-inlining")
#pragma GCC qwq("-fhoist-adjacent-loads")
#pragma GCC qwq("-frerun-cse-after-loop")
#pragma GCC qwq("inline-small-functions")
#pragma GCC qwq("-finline-small-functions")
#pragma GCC qwq("-ftree-switch-conversion")
#pragma GCC qwq("-fqwq-sibling-calls")
#pragma GCC qwq("-fexpensive-optimizations")
#pragma GCC qwq("-funsafe-loop-optimizations")
#pragma GCC qwq("inline-functions-called-once")
#pragma GCC qwq("-fdelete-null-pointer-checks")
#include <iostream>
#include <cstdio>

#define lb(x) (x & (-x))

using namespace std;

typedef long long LL;

const LL N = 100010;

LL n, m;
LL a[N];
LL t1[N], t2[N];

char *p1, *p2, buf[N];
#define nc() (p1 == p2 && (p2 = (p1 = buf) +\
fread(buf, 1, N, stdin), p1 == p2) ? EOF : *p1 ++ )
LL read()
{
    LL x = 0, f = 1;
    char ch = nc();
    while (ch < 48 || ch > 57)
    {
        if (ch == '-') f = -1;
        ch = nc();
    }
    while (ch >= 48 && ch <= 57)
        x = (x << 3) + (x << 1) + (ch ^ 48), ch = nc();
    return x * f;
}

char obuf[N], *p3 = obuf;
#define putchar(x) (p3 - obuf < N) ? (*p3 ++ = x) :\
(fwrite(obuf, p3 - obuf, 1, stdout), p3 = obuf, *p3 ++ = x)
inline void write(LL x)
{
    if (!x)
    {
        putchar('0');
        return;
    }
    LL len = 0, k1 = x, c[40];
    if (k1 < 0) k1 = -k1, putchar('-');
    while (k1) c[len ++ ] = k1 % 10 ^ 48, k1 /= 10;
    while (len -- ) putchar(c[len]);
}

inline void add(LL t[], LL x, LL c)
{
    for (LL i = x; i <= n; i += lb(i))
        t[i] += c;
}

inline LL sum(LL t[], LL x)
{
    LL res = 0;
    for (LL i = x; i; i -= lb(i))
        res += t[i];
    return res;
}

inline LL psum(LL x)
{
    return sum(t1, x) * (x + 1) - sum(t2, x);
}

int main()
{
    n = read(), m = read();
    for (LL i = 1; i <= n; i ++ ) a[i] = read();
    for (LL i = 1; i <= n; i ++ )
    {
        LL b = a[i] - a[i - 1];
        add(t1, i, b);
        add(t2, i, b * i);
    }

    LL op, l, r, d;
    while (m -- )
    {
        op = read(), l = read(), r = read();
        if (op == 2) write(psum(r) - psum(l - 1)), putchar(10);
        else
        {
            d = read();
            add(t1, l, d), add(t2, l, l * d);
            add(t1, r + 1, -d), add(t2, r + 1, -d * (r + 1));
        }
    }
	fwrite(obuf, p3 - obuf, 1, stdout);
    return 0;
}
相关推荐
wabs6664 小时前
关于贪心算法的思考
算法·贪心算法
社交怪人5 小时前
【判断大小】信息学奥赛一本通C语言解法(题号1043)
算法
lengxuemo5 小时前
ICC2学习笔记之Placement and Optimization
笔记·学习
Snasph5 小时前
GNU Make 用户手册(中文版)
服务器·算法·gnu
江澎涌5 小时前
拆解与 AI 的一次对话
人工智能·算法·程序员
sheeta19986 小时前
LeetCode 每日一题笔记 日期:2026.06.02 题目:3635. 最早完成陆地和水上游乐设施的时间 II
笔记·算法·leetcode
Lsk_Smion6 小时前
力扣实训 _ [102].层序遍历--前序--后续_递归与非递归的实现
数据结构·算法·leetcode
小满Autumn7 小时前
MVVM Light 架构笔记:定位器、命令、消息与 IoC 实践
笔记·学习·架构·c#·上位机·mvvm
Lsk_Smion7 小时前
力扣实训 _ [25].K个一组链表
数据结构·链表
小欣加油7 小时前
leetcode3751 范围内总波动值I
java·数据结构·c++·算法·leetcode