树状数组
简介¶
树状数组和线段树具有相似的功能,但他俩毕竟还有一些区别:树状数组能有的操作,线段树一定有;线段树有的操作,树状数组不一定有。但是树状数组的代码要比线段树短,思维更清晰,速度也更快,在解决一些单点修改的问题时,树状数组是不二之选。
原理¶
下面这张图展示了树状数组的工作原理:
这个结构和线段树有些类似:用一个大节点表示一些小节点的信息,进行查询的时候只需要查询一些大节点而不是所有的小节点。
最上面的八个方块就代表数组
他们下面的参差不齐的剩下的方块就代表数组
从图中可以看出:
如果要计算数组
从
用法及操作¶
那么问题来了,怎么知道 lowbit
:
1 2 3 4 5 6 7 8 9 | // C++ Version
int lowbit(int x) {
// x 的二进制表示中,最低位的 1 的位置。
// lowbit(0b10110000) == 0b00010000
// ~~~^~~~~
// lowbit(0b11100100) == 0b00000100
// ~~~~~^~~
return x & -x;
}
|
1 2 3 4 5 6 7 8 9 10 | # Python Version
def lowbit(x):
"""
x 的二进制表示中,最低位的 1 的位置。
lowbit(0b10110000) == 0b00010000
~~~^~~~~
lowbit(0b11100100) == 0b00000100
~~~~~^~~
"""
return x & -x
|
注释说明了 lowbit
的意思,对于
发现第一个
在常见的计算机中,有符号数采用补码表示。在补码表示下,数
使用 lowbit 函数,我们可以实现很多操作,例如单点修改,将
1 2 3 4 5 6 7 | // C++ Version
void add(int x, int k) {
while (x <= n) { // 不能越界
c[x] = c[x] + k;
x = x + lowbit(x);
}
}
|
1 2 3 4 5 | # Python Version
def add(x, k):
while x <= n: # 不能越界
c[x] = c[x] + k
x = x + lowbit(x)
|
前缀求和:
1 2 3 4 5 6 7 8 9 | // C++ Version
int getsum(int x) { // a[1]..a[x]的和
int ans = 0;
while (x >= 1) {
ans = ans + c[x];
x = x - lowbit(x);
}
return ans;
}
|
1 2 3 4 5 6 7 | # Python Version
def getsum(x): # a[1]..a[x]的和
ans = 0
while x >= 1:
ans = ans + c[x]
x = x - lowbit(x)
return ans
|
区间加 & 区间求和¶
若维护序列
进行推导
区间和可以用两个前缀和相减得到,因此只需要用两个树状数组分别维护
代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | // C++ Version
int t1[MAXN], t2[MAXN], n;
inline int lowbit(int x) { return x & (-x); }
void add(int k, int v) {
int v1 = k * v;
while (k <= n) {
t1[k] += v, t2[k] += v1;
k += lowbit(k);
}
}
int getsum(int *t, int k) {
int ret = 0;
while (k) {
ret += t[k];
k -= lowbit(k);
}
return ret;
}
void add1(int l, int r, int v) {
add(l, v), add(r + 1, -v); // 将区间加差分为两个前缀加
}
long long getsum1(int l, int r) {
return (r + 1ll) * getsum(t1, r) - 1ll * l * getsum(t1, l - 1) -
(getsum(t2, r) - getsum(t2, l - 1));
}
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | # Python Version
t1 = [0] * MAXN, t2 = [0] * MAXN; n = 0
def lowbit(x):
return x & (-x)
def add(k, v):
v1 = k * v
while k <= n:
t1[k] = t1[k] + v; t2[k] = t2[k] + v1
k = k + lowbit(k)
def getsum(t, k):
ret = 0
while k:
ret = ret + t[k]
k = k - lowbit(k)
return ret
def add1(l, r, v):
add(l, v)
add(r + 1, -v)
def getsum1(l, r):
return (r) * getsum(t1, r) - l * getsum(t1, l - 1) - \
(getsum(t2, r) - getsum(t2, l - 1))
|
Tricks¶
每一个节点的值是由所有与自己直接相连的儿子的值求和得到的。因此可以倒着考虑贡献,即每次确定完儿子的值后,用自己的值更新自己的直接父亲。
1 2 3 4 5 6 7 8 9 | // C++ Version
// O(n)建树
void init() {
for (int i = 1; i <= n; ++i) {
t[i] += a[i];
int j = i + lowbit(i);
if (j <= n) t[j] += t[i];
}
}
|
1 2 3 4 5 6 7 | # Python Version
def init():
for i in range(1, n + 1):
t[i] = t[i] + a[i]
j = i + lowbit(i)
if j <= n:
t[j] = t[j] + t[i]
|
参考 "可持久化线段树" 章节中,关于求区间第
因此可以想到算法:如果已经找到
- 求出
depth=\left \lfloor \log_2n \right \rfloor - 计算
t=\sum_{i=x+1}^{x+2^{depth}}a_i - 如果
sum+t \le k 2^{depth} x x - 将
depth depth
1 2 3 4 5 6 7 8 9 10 11 12 13 | // C++ Version
// 权值树状数组查询第k小
int kth(int k) {
int cnt = 0, ret = 0;
for (int i = log2(n); ~i; --i) { // i 与上文 depth 含义相同
ret += 1 << i; // 尝试扩展
if (ret >= n || cnt + t[ret] >= k) // 如果扩展失败
ret -= 1 << i;
else
cnt += t[ret]; // 扩展成功后 要更新之前求和的值
}
return ret + 1;
}
|
1 2 3 4 5 6 7 8 9 10 11 12 | # Python Version
# 权值树状数组查询第 k 小
def kth(k):
cnt = 0; ret = 0
i = log2(n) # i 与上文 depth 含义相同
while ~i:
ret = ret + (1 << i) # 尝试扩展
if ret >= n or cnt + t[ret] >= k: # 如果扩展失败
ret = ret - (1 << i)
else:
cnt = cnt + t[ret] # 扩展成功后 要更新之前求和的值
return ret + 1
|
时间戳优化:
对付多组数据很常见的技巧。如果每次输入新数据时,都暴力清空树状数组,就可能会造成超时。因此使用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | // C++ Version
// 时间戳优化
int tag[MAXN], t[MAXN], Tag;
void reset() { ++Tag; }
void add(int k, int v) {
while (k <= n) {
if (tag[k] != Tag) t[k] = 0;
t[k] += v, tag[k] = Tag;
k += lowbit(k);
}
}
int getsum(int k) {
int ret = 0;
while (k) {
if (tag[k] == Tag) ret += t[k];
k -= lowbit(k);
}
return ret;
}
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | # Python Version
# 时间戳优化
tag = [0] * MAXN; t = [0] * MAXN; Tag = 0
def reset():
Tag = Tag + 1
def add(k, v):
while k <= n:
if tag[k] != Tag:
t[k] = 0
t[k] = t[k] + v
tag[k] = Tag
k = k + lowbit(k)
def getsum(k):
ret = 0
while k:
if tag[k] == Tag:
ret = ret + t[k]
k = k - lowbit(k)
return ret
|
例题¶
build本页面最近更新:,更新历史
edit发现错误?想一起完善? 在 GitHub 上编辑此页!
people本页面贡献者:HeRaNO, Zhoier, Ir1d, Xeonacid, wangdehu, ouuan, ranwen, ananbaobeichicun, Ycrpro
copyright本页面的全部内容在 CC BY-SA 4.0 和 SATA 协议之条款下提供,附加条款亦可能应用