简介

树状数组可以看做简化版的线段树,它进行单点修改和区间查询的常数是比线段树更优的。

对于原数组 a[n]a[n],树状数组 c[n]c[n] 是一个等长的数组,并且对于任意 c[i]c[i] 表示以 a[i]a[i] 结尾且长度为 lowbit(i)\text{lowbit}(i) 的区间和。

下面用 f(x)f(x) 代表 lowbit(x)\text{lowbit}(x)L(x),R(x)L(x),R(x) 代表 xx 覆盖的左闭右开区间,证明其正确性。

A0=k2i+1+2iA_0=k2^{i+1}+2^i,则 f(A0)=2i,L(A0)=A0f(A0)=k2i+1,R(A0)=A0f(A_0)=2^i,L(A_0)=A_0-f(A_0)=k2^{i+1},R(A_0)=A_0

A1=A0+f(A0)=(k+1)2i+1A_1=A_0+f(A_0)=(k+1)2^{i+1},那么 L(A1)=A1f(A1)=(k+1)2i+1f(A1)L(A_1)=A_1-f(A_1)=(k+1)2^{i+1}-f(A_1)

由于 f(A1)2i+1f(A_1)\ge 2^{i+1},所以 L(A1)k2i+1=L(A0)L(A_1)\le k2^{i+1}=L(A_0),且 R(A1)=A1=(k+1)2i+1>A0=R(A0)R(A_1)=A_1=(k+1)2^{i+1}\gt A_0=R(A_0),所以 A1A_1 一定覆盖 A0A_0

所以一个点的父节点是 x+f(x)x+f(x),它左边与它相邻且不是它子结点的点是 L(x)=xf(x)L(x)=x-f(x)

模板题

已知一个数列,你需要进行下面两种操作:

  • 将某一个数加上 xx

  • 求出某区间每一个数的和

题目链接:P3374

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <cstdio>
#include <algorithm>
using namespace std;

const int N = 500010;
int tr[N], n, m;

#define lowbit(x) ((x)&(-x))

void add(int p, int v) {
for (; p < N; p += lowbit(p))
tr[p] += v;
}

int query(int p) {
int res = 0;
for (; p; p -= lowbit(p))
res += tr[p];
return res;
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
int v; scanf("%d", &v);
add(i, v);
}

while (m--) {
int op, x, y;
scanf("%d%d%d", &op, &x, &y);
if (op == 1) add(x, y);
else printf("%d\n", query(y) - query(x-1));
}
return 0;
}

逆序对

定义:i<ji<jai>aja_i>a_j 就称为一个逆序对,统计逆序对数目。

题目链接:P1908

本题可以用归并排序那样的分治算法,并且它更好,但是这里我们用树状数组来解决这个问题。

首先注意到值域比较大,所以需要离散化。当枚举到 aia_i 时,我们需要知道前面有多少个数大于 aia_i,如果我们用树状数组来统计每个数字出现的次数,也就是求一下 [ai+1,n][a_i+1,n] 的区间和,其中 nn 是离散化后的最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;

typedef long long LL;
const int N = 500010;
int n, a[N];
LL tr[N];
vector<int> nums;

#define lowbit(x) ((x)&(-x))

void add(int p, int v) {
for (; p < N; p += lowbit(p))
tr[p] += v;
}

LL query(int p) {
LL res = 0;
for (; p; p -= lowbit(p))
res += tr[p];
return res;
}

LL query(int l, int r) {
return query(r) - query(l-1);
}

int find(int x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
LL res = 0;
for (int i = 1; i <= n; i++) {
int t = find(a[i]);
res += query(t+1, nums.size());
add(t, 1);
}
printf("%lld\n", res);
return 0;
}

The Battle of Chibi (LIS)

简单题意:给 TT 组数据,长度为 nn 的数列 aa 中,找出长度为 mm 的严格上升子序列的个数,答案对 1e9+7 取模。

题目链接:UVA12983UVA12983(Luogu)

这是一道 DP 题,但是可以用树状数组来加速。

  • 状态表示 f(i,j)f(i,j):长度为 iiaja_j 结尾的最长上升子序列的个数。

  • 状态转移:

    f(i,j)=ak<aj,k<jf(i1,k)f(i,j)=\sum_{a_k<a_j, k<j} f(i-1,k)

    只要我们在循环到 jj 时把之前的所有 f(i1,k)f(i-1,k)ak<aja_k<a_j 的值累加起来即可,这可以用树状数组优化,树状数组的索引是 aia_i 的值,值是每个 f(i1,k)f(i-1,k) 每次求的都是 [0,ai1][0,a_i-1] 间所有满足要求的值之和。

    由于牵扯到值域的问题,这里就需要用到离散化。

  • 初始化:f(1,j)=1f(1,j)=1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;

const int N = 1010, mod = 1e9+7;
vector<int> nums;
int tr[N], n, m;
int a[N], f[N][N];
#define lowbit(x) ((x)&(-x))

inline void add(int p, int v) {
for (; p < N; p += lowbit(p)) {
tr[p] = (tr[p] + v) % mod;
}
}

inline int query(int p) {
int res = 0;
for (; p; p -= lowbit(p))
res = (res + tr[p]) % mod;
return res;
}

int find(int x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

int solve() {
// clear f
for (int j = 1; j <= n; j++) f[1][j] = 1;
for (int i = 2; i <= m; i++) {
memset(tr, 0, sizeof tr);
for (int j = 1; j <= n; j++) {
f[i][j] = query(a[j]-1);
add(a[j], f[i-1][j]);
}
}
int res = 0;
for (int j = 1; j <= n; j++) res = (res + f[m][j]) % mod;
return res;
}

int main() {
int T; scanf("%d", &T);
for (int C = 1; C <= T; C++) {
nums.clear();
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums.push_back(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
for (int i = 1; i <= n; i++) {
a[i] = find(a[i]);
}
printf("Case #%d: %d\n", C, solve());
}
return 0;
}

[THUPC2024 初赛] 二进制

小 F 给出了一个这里有一个长为 n106n\le 10^6 的二进制串 ss,下标从 11nn,且 i[1,n],si{0,1}\forall i\in[1,n],s_i\in \{0,1\},他想要删除若干二进制子串。

具体的,小 F 做出了 nn 次尝试。

在第 i[1,n]i\in[1,n] 次尝试中,他会先写出正整数 ii 的二进制串表示 tt(无前导零,左侧为高位,例如 1010 可以写为 10101010)。

随后找到这个二进制表示 ttss 中从左到右 第一次 出现的位置,并删除这个串。

注意,删除后左右部分的串会拼接在一起 形成一个新的串,请注意新串下标的改变。

若当前 tt 不在 ss 中存在,则小 F 对串 ss 不作出改变。

你需要回答每一次尝试中,ttss 中第一次出现的位置的左端点以及 ttss 中出现了多少次。

定义两次出现不同当且仅当出现的位置的左端点不同。

请注意输入输出效率。

题目链接:LOJ 6906

首先,对于一个数字 nn,它的二进制位数是 log2n+1\lfloor \log_2 n\rfloor + 1,为了方便阅读,下文用 logn\log n 代替。

我们每次删除的都是 1n1\sim n 的二进制串,因此它的长度应该小于等于 logn\log n,所以我们对于原串 ss 的每位 ii 都向后 logn\log n 个字符都扫描一遍,假设获得到的数字是 vv,那么就添加一个 viv\to i 的映射关系。

删除的时候就在这个位置的前 logn\log n,后 logv\log v 都更新一下。这样更新的原因是后 logv\log v 个元素是被删除掉的,所以需要删掉对应的映射关系;前 logn\log n 个向后关系有所改变,所以需要重新更新。

考虑如何处理坐标的变化,肯定是不能在映射里面修改的,这样复杂度太大了,所以用记录偏移量的方式,每次删除相当于在一个位置的后面的真实坐标都要减去 logv\log v,所以这里开一个树状数组就可以了。

然后考虑这个映射关系如何维护,由于 vvnn 数量级相同,所以可以直接开数组,然后这个数组内需要一个支持动态 add,delete,min,size 的数据结构,所以用可删堆来实现,相较于 std::set 常数要小一些。

删除之后再次访问的时候需要跳过删除掉的地方,这个操作可以用链表实现。由于更新的时候需要向前走,所以是双链表。

复杂度分析:删除的时候每次更新 O(logn)O(\log n) 个元素,每次更新的复杂度是 O(log2n)O(\log^2 n),所以删除的复杂度是 O(log3n)O(\log^3 n)

但是一共只会执行 O(nlogn)O(\frac{n}{\log n}) 次删除,所以最终的复杂度是 O(nlog2n)O(n\log^2 n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>
using namespace std;

#define lowbit(x) ((x)&(-x))
const int N = 1200010;

struct RemovableHeap {
priority_queue<int, vector<int>, greater<int>> v, d;
void remove(int x) {
d.push(x);
}

int top() {
while (d.size() && v.top() == d.top()) v.pop(), d.pop();
return v.top();
}

void push(int x) {
v.push(x);
}

int size() {
return v.size() - d.size();
}
} heap[N];

int n, a[N], lg2[N], nxt[N], pre[N], lg2n;

struct BIT {
int tr[N];

int query(int p) {
int res = 0;
for (; p; p -= lowbit(p)) res += tr[p];
return res;
}

void add(int p, int k) {
for (; p < N; p += lowbit(p)) tr[p] += k;
}
} bit;

void update(int s, bool add = true) {
if (a[s] == 0) return;

for (int i = s, now = 0, cnt = 0; i && cnt < lg2n; i = nxt[i], cnt++) {
now = now << 1 | a[i];
if (add) heap[now].push(s);
else heap[now].remove(s);
}
}

void remove(int s, int bits) {
int ne = s, raws = s;
bit.add(s, -bits);
for (int i = 0; i < bits; i++) {
update(ne, false);
ne = nxt[ne];
}

pre[ne] = pre[s];
for (int i = 0; s && i < lg2n; i++) {
s = pre[s];
update(s, false);
}
s = raws;
if (pre[s]) nxt[pre[s]] = ne;
for (int i = 0; s && i < lg2n; i++) {
s = pre[s];
update(s);
}
}

int main() {
static char s[N];
scanf("%d%s", &n, s+1);
for (int i = 1; i <= n; i++) a[i] = s[i] ^ 48;
// lg2[i] 表示 floor(log2(i)) + 1
lg2[1] = 0;
for (int i = 2; i <= n; i++) lg2[i] = lg2[i>>1] + 1;

for (int i = 1; i <= n; i++) lg2[i]++, nxt[i] = i+1, pre[i] = i-1;
nxt[n] = 0, lg2n = lg2[n];

for (int i = 1; i <= n; i++) update(i);
for (int i = 1; i <= n; i++) {
if (heap[i].size() == 0) puts("-1 0");
else {
int p = heap[i].top(), sz = heap[i].size();
printf("%d %d\n", p + bit.query(p), sz);
remove(p, lg2[i]);
}
}

return 0;
}

[CF1915F] Greetings

多测,给定 nn 个区间 [li,ri][l_i,r_i] 求出每个区间包含的区间个数之和,n2×105,109liri109\sum n\le 2\times 10^5,-10^9\le l_i\le r_i\le10^9

题目链接:CF1915F

首先排序用时间维度代替 rir_i 维度,然后在所有 rjrir_j\le r_i 的区间中查找 ljlil_j\ge l_i 的数量,先考虑如何将当前区间贡献出去,可以在 (,li](-\infin,l_i] 加一,这样查询 ll 的时候就会在当 llil\le l_i 的时候加上当前区间的贡献,这样是正确的。

将坐标离散化成 [1,2n][1,2n] 中的数字。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;

#define lowbit(x) ((x)&(-x))
#define int long long
#define eb emplace_back
const int N = 200010;
int n;

struct BIT {
int tr[N*2];

void init(int n) {
for (int i = 1; i <= n; i++) tr[i] = 0;
}

void add(int p, int k) {
for (; p < N*2; p += lowbit(p)) tr[p] += k;
}

int query(int p) {
int res = 0;
for (; p; p -= lowbit(p)) res += tr[p];
return res;
}
} bit;

struct Object {
int l, r;
bool operator<(const Object& obj) const {
if (r == obj.r) return l < obj.l;
return r < obj.r;
}
} obj[N];

int solve() {
cin >> n;
vector<int> nums;
nums.reserve(n*2);
for (int i = 1; i <= n; i++) cin >> obj[i].l >> obj[i].r, nums.eb(obj[i].l), nums.eb(obj[i].r);
sort(nums.begin(), nums.end());
sort(obj+1, obj+1+n);

int ans = 0;
bit.init(n*2);
for (int i = 1; i <= n; i++) {
int l = lower_bound(nums.begin(), nums.end(), obj[i].l) - nums.begin() + 1;
ans += bit.query(l);
bit.add(l+1, -1);
bit.add(1, 1);
}

return ans;
}

signed main() {
int t;
cin >> t;
while (t--) cout << solve() << '\n';
return 0;
}