简介
树状数组可以看做简化版的线段树,它进行单点修改和区间查询的常数是比线段树更优的。
对于原数组 a [ n ] a[n] a [ n ] ,树状数组 c [ n ] c[n] c [ n ] 是一个等长的数组,并且对于任意 c [ i ] c[i] c [ i ] 表示以 a [ i ] a[i] a [ i ] 结尾且长度为 lowbit ( i ) \text{lowbit}(i) lowbit ( i ) 的区间和。
下面用 f ( x ) f(x) f ( x ) 代表 lowbit ( x ) \text{lowbit}(x) lowbit ( x ) ,L ( x ) , R ( x ) L(x),R(x) L ( x ) , R ( x ) 代表 x x x 覆盖的左闭右开区间,证明其正确性。
若 A 0 = k 2 i + 1 + 2 i A_0=k2^{i+1}+2^i A 0 = k 2 i + 1 + 2 i ,则 f ( A 0 ) = 2 i , L ( A 0 ) = A 0 − f ( A 0 ) = k 2 i + 1 , R ( A 0 ) = A 0 f(A_0)=2^i,L(A_0)=A_0-f(A_0)=k2^{i+1},R(A_0)=A_0 f ( A 0 ) = 2 i , L ( A 0 ) = A 0 − f ( A 0 ) = k 2 i + 1 , R ( A 0 ) = A 0 。
则 A 1 = A 0 + f ( A 0 ) = ( k + 1 ) 2 i + 1 A_1=A_0+f(A_0)=(k+1)2^{i+1} A 1 = A 0 + f ( A 0 ) = ( k + 1 ) 2 i + 1 ,那么 L ( A 1 ) = A 1 − f ( A 1 ) = ( k + 1 ) 2 i + 1 − f ( A 1 ) L(A_1)=A_1-f(A_1)=(k+1)2^{i+1}-f(A_1) L ( A 1 ) = A 1 − f ( A 1 ) = ( k + 1 ) 2 i + 1 − f ( A 1 ) 。
由于 f ( A 1 ) ≥ 2 i + 1 f(A_1)\ge 2^{i+1} f ( A 1 ) ≥ 2 i + 1 ,所以 L ( A 1 ) ≤ k 2 i + 1 = L ( A 0 ) L(A_1)\le k2^{i+1}=L(A_0) L ( A 1 ) ≤ k 2 i + 1 = L ( A 0 ) ,且 R ( A 1 ) = A 1 = ( k + 1 ) 2 i + 1 > A 0 = R ( A 0 ) R(A_1)=A_1=(k+1)2^{i+1}\gt A_0=R(A_0) R ( A 1 ) = A 1 = ( k + 1 ) 2 i + 1 > A 0 = R ( A 0 ) ,所以 A 1 A_1 A 1 一定覆盖 A 0 A_0 A 0 。
所以一个点的父节点是 x + f ( x ) x+f(x) x + f ( x ) ,它左边与它相邻且不是它子结点的点是 L ( x ) = x − f ( x ) L(x)=x-f(x) L ( x ) = x − f ( x ) 。
模板题
已知一个数列,你需要进行下面两种操作:
题目链接:P3374 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 #include <cstdio> #include <algorithm> using namespace std;const int N = 500010 ;int tr[N], n, m;#define lowbit(x) ((x)&(-x)) void add (int p, int v) { for (; p < N; p += lowbit (p)) tr[p] += v; } int query (int p) { int res = 0 ; for (; p; p -= lowbit (p)) res += tr[p]; return res; } int main () { scanf ("%d%d" , &n, &m); for (int i = 1 ; i <= n; i++) { int v; scanf ("%d" , &v); add (i, v); } while (m--) { int op, x, y; scanf ("%d%d%d" , &op, &x, &y); if (op == 1 ) add (x, y); else printf ("%d\n" , query (y) - query (x-1 )); } return 0 ; }
逆序对
定义:i < j i<j i < j 且 a i > a j a_i>a_j a i > a j 就称为一个逆序对,统计逆序对数目。
题目链接:P1908 。
本题可以用归并排序那样的分治算法,并且它更好,但是这里我们用树状数组来解决这个问题。
首先注意到值域比较大,所以需要离散化。当枚举到 a i a_i a i 时,我们需要知道前面有多少个数大于 a i a_i a i ,如果我们用树状数组来统计每个数字出现的次数,也就是求一下 [ a i + 1 , n ] [a_i+1,n] [ a i + 1 , n ] 的区间和,其中 n n n 是离散化后的最大值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 #include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std;typedef long long LL;const int N = 500010 ;int n, a[N];LL tr[N]; vector<int > nums; #define lowbit(x) ((x)&(-x)) void add (int p, int v) { for (; p < N; p += lowbit (p)) tr[p] += v; } LL query (int p) { LL res = 0 ; for (; p; p -= lowbit (p)) res += tr[p]; return res; } LL query (int l, int r) { return query (r) - query (l-1 ); } int find (int x) { return lower_bound (nums.begin (), nums.end (), x) - nums.begin () + 1 ; } int main () { scanf ("%d" , &n); for (int i = 1 ; i <= n; i++) { scanf ("%d" , &a[i]); nums.push_back (a[i]); } sort (nums.begin (), nums.end ()); nums.erase (unique (nums.begin (), nums.end ()), nums.end ()); LL res = 0 ; for (int i = 1 ; i <= n; i++) { int t = find (a[i]); res += query (t+1 , nums.size ()); add (t, 1 ); } printf ("%lld\n" , res); return 0 ; }
The Battle of Chibi (LIS)
简单题意:给 T T T 组数据,长度为 n n n 的数列 a a a 中,找出长度为 m m m 的严格上升子序列的个数,答案对 1e9+7
取模。
题目链接:UVA12983 ,UVA12983(Luogu) 。
这是一道 DP 题,但是可以用树状数组来加速。
状态表示 f ( i , j ) f(i,j) f ( i , j ) :长度为 i i i 以 a j a_j a j 结尾的最长上升子序列的个数。
状态转移:
f ( i , j ) = ∑ a k < a j , k < j f ( i − 1 , k ) f(i,j)=\sum_{a_k<a_j, k<j} f(i-1,k)
f ( i , j ) = a k < a j , k < j ∑ f ( i − 1 , k )
只要我们在循环到 j j j 时把之前的所有 f ( i − 1 , k ) f(i-1,k) f ( i − 1 , k ) 中 a k < a j a_k<a_j a k < a j 的值累加起来即可,这可以用树状数组优化,树状数组的索引是 a i a_i a i 的值,值是每个 f ( i − 1 , k ) f(i-1,k) f ( i − 1 , k ) 每次求的都是 [ 0 , a i − 1 ] [0,a_i-1] [ 0 , a i − 1 ] 间所有满足要求的值之和。
由于牵扯到值域的问题,这里就需要用到离散化。
初始化:f ( 1 , j ) = 1 f(1,j)=1 f ( 1 , j ) = 1
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 #include <cstdio> #include <cstring> #include <algorithm> #include <vector> using namespace std;const int N = 1010 , mod = 1e9 +7 ;vector<int > nums; int tr[N], n, m;int a[N], f[N][N];#define lowbit(x) ((x)&(-x)) inline void add (int p, int v) { for (; p < N; p += lowbit (p)) { tr[p] = (tr[p] + v) % mod; } } inline int query (int p) { int res = 0 ; for (; p; p -= lowbit (p)) res = (res + tr[p]) % mod; return res; } int find (int x) { return lower_bound (nums.begin (), nums.end (), x) - nums.begin () + 1 ; } int solve () { for (int j = 1 ; j <= n; j++) f[1 ][j] = 1 ; for (int i = 2 ; i <= m; i++) { memset (tr, 0 , sizeof tr); for (int j = 1 ; j <= n; j++) { f[i][j] = query (a[j]-1 ); add (a[j], f[i-1 ][j]); } } int res = 0 ; for (int j = 1 ; j <= n; j++) res = (res + f[m][j]) % mod; return res; } int main () { int T; scanf ("%d" , &T); for (int C = 1 ; C <= T; C++) { nums.clear (); scanf ("%d%d" , &n, &m); for (int i = 1 ; i <= n; i++) { scanf ("%d" , &a[i]); nums.push_back (a[i]); } sort (nums.begin (), nums.end ()); nums.erase (unique (nums.begin (), nums.end ()), nums.end ()); for (int i = 1 ; i <= n; i++) { a[i] = find (a[i]); } printf ("Case #%d: %d\n" , C, solve ()); } return 0 ; }
[THUPC2024 初赛] 二进制
小 F 给出了一个这里有一个长为 n ≤ 1 0 6 n\le 10^6 n ≤ 1 0 6 的二进制串 s s s ,下标从 1 1 1 到 n n n ,且 ∀ i ∈ [ 1 , n ] , s i ∈ { 0 , 1 } \forall i\in[1,n],s_i\in \{0,1\} ∀ i ∈ [ 1 , n ] , s i ∈ { 0 , 1 } ,他想要删除若干二进制子串。
具体的,小 F 做出了 n n n 次尝试。
在第 i ∈ [ 1 , n ] i\in[1,n] i ∈ [ 1 , n ] 次尝试中,他会先写出正整数 i i i 的二进制串表示 t t t (无前导零,左侧为高位,例如 10 10 10 可以写为 1010 1010 1010 )。
随后找到这个二进制表示 t t t 在 s s s 中从左到右 第一次 出现的位置,并删除这个串。
注意,删除后左右部分的串会拼接在一起 形成一个新的串 ,请注意新串下标的改变。
若当前 t t t 不在 s s s 中存在,则小 F 对串 s s s 不作出改变。
你需要回答每一次尝试中,t t t 在 s s s 中第一次出现的位置的左端点以及 t t t 在 s s s 中出现了多少次。
定义两次出现不同当且仅当出现的位置的左端点不同。
请注意输入输出效率。
题目链接:LOJ 6906 。
首先,对于一个数字 n n n ,它的二进制位数是 ⌊ log 2 n ⌋ + 1 \lfloor \log_2 n\rfloor + 1 ⌊ log 2 n ⌋ + 1 ,为了方便阅读,下文用 log n \log n log n 代替。
我们每次删除的都是 1 ∼ n 1\sim n 1 ∼ n 的二进制串,因此它的长度应该小于等于 log n \log n log n ,所以我们对于原串 s s s 的每位 i i i 都向后 log n \log n log n 个字符都扫描一遍,假设获得到的数字是 v v v ,那么就添加一个 v → i v\to i v → i 的映射关系。
删除的时候就在这个位置的前 log n \log n log n ,后 log v \log v log v 都更新一下。这样更新的原因是后 log v \log v log v 个元素是被删除掉的,所以需要删掉对应的映射关系;前 log n \log n log n 个向后关系有所改变,所以需要重新更新。
考虑如何处理坐标的变化,肯定是不能在映射里面修改的,这样复杂度太大了,所以用记录偏移量的方式,每次删除相当于在一个位置的后面的真实坐标都要减去 log v \log v log v ,所以这里开一个树状数组就可以了。
然后考虑这个映射关系如何维护,由于 v v v 于 n n n 数量级相同,所以可以直接开数组,然后这个数组内需要一个支持动态 add,delete,min,size
的数据结构,所以用可删堆来实现,相较于 std::set
常数要小一些。
删除之后再次访问的时候需要跳过删除掉的地方,这个操作可以用链表实现。由于更新的时候需要向前走,所以是双链表。
复杂度分析:删除的时候每次更新 O ( log n ) O(\log n) O ( log n ) 个元素,每次更新的复杂度是 O ( log 2 n ) O(\log^2 n) O ( log 2 n ) ,所以删除的复杂度是 O ( log 3 n ) O(\log^3 n) O ( log 3 n ) 。
但是一共只会执行 O ( n log n ) O(\frac{n}{\log n}) O ( l o g n n ) 次删除,所以最终的复杂度是 O ( n log 2 n ) O(n\log^2 n) O ( n log 2 n ) 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 #include <bits/stdc++.h> using namespace std;#define lowbit(x) ((x)&(-x)) const int N = 1200010 ;struct RemovableHeap { priority_queue<int , vector<int >, greater<int >> v, d; void remove (int x) { d.push (x); } int top () { while (d.size () && v.top () == d.top ()) v.pop (), d.pop (); return v.top (); } void push (int x) { v.push (x); } int size () { return v.size () - d.size (); } } heap[N]; int n, a[N], lg2[N], nxt[N], pre[N], lg2n;struct BIT { int tr[N]; int query (int p) { int res = 0 ; for (; p; p -= lowbit (p)) res += tr[p]; return res; } void add (int p, int k) { for (; p < N; p += lowbit (p)) tr[p] += k; } } bit; void update (int s, bool add = true ) { if (a[s] == 0 ) return ; for (int i = s, now = 0 , cnt = 0 ; i && cnt < lg2n; i = nxt[i], cnt++) { now = now << 1 | a[i]; if (add) heap[now].push (s); else heap[now].remove (s); } } void remove (int s, int bits) { int ne = s, raws = s; bit.add (s, -bits); for (int i = 0 ; i < bits; i++) { update (ne, false ); ne = nxt[ne]; } pre[ne] = pre[s]; for (int i = 0 ; s && i < lg2n; i++) { s = pre[s]; update (s, false ); } s = raws; if (pre[s]) nxt[pre[s]] = ne; for (int i = 0 ; s && i < lg2n; i++) { s = pre[s]; update (s); } } int main () { static char s[N]; scanf ("%d%s" , &n, s+1 ); for (int i = 1 ; i <= n; i++) a[i] = s[i] ^ 48 ; lg2[1 ] = 0 ; for (int i = 2 ; i <= n; i++) lg2[i] = lg2[i>>1 ] + 1 ; for (int i = 1 ; i <= n; i++) lg2[i]++, nxt[i] = i+1 , pre[i] = i-1 ; nxt[n] = 0 , lg2n = lg2[n]; for (int i = 1 ; i <= n; i++) update (i); for (int i = 1 ; i <= n; i++) { if (heap[i].size () == 0 ) puts ("-1 0" ); else { int p = heap[i].top (), sz = heap[i].size (); printf ("%d %d\n" , p + bit.query (p), sz); remove (p, lg2[i]); } } return 0 ; }
[CF1915F] Greetings
多测,给定 n n n 个区间 [ l i , r i ] [l_i,r_i] [ l i , r i ] 求出每个区间包含的区间个数之和,∑ n ≤ 2 × 1 0 5 , − 1 0 9 ≤ l i ≤ r i ≤ 1 0 9 \sum n\le 2\times 10^5,-10^9\le l_i\le r_i\le10^9 ∑ n ≤ 2 × 1 0 5 , − 1 0 9 ≤ l i ≤ r i ≤ 1 0 9 。
题目链接:CF1915F 。
首先排序用时间维度代替 r i r_i r i 维度,然后在所有 r j ≤ r i r_j\le r_i r j ≤ r i 的区间中查找 l j ≥ l i l_j\ge l_i l j ≥ l i 的数量,先考虑如何将当前区间贡献出去,可以在 ( − ∞ , l i ] (-\infin,l_i] ( − ∞ , l i ] 加一,这样查询 l l l 的时候就会在当 l ≤ l i l\le l_i l ≤ l i 的时候加上当前区间的贡献,这样是正确的。
将坐标离散化成 [ 1 , 2 n ] [1,2n] [ 1 , 2 n ] 中的数字。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 #include <bits/stdc++.h> using namespace std;#define lowbit(x) ((x)&(-x)) #define int long long #define eb emplace_back const int N = 200010 ;int n;struct BIT { int tr[N*2 ]; void init (int n) { for (int i = 1 ; i <= n; i++) tr[i] = 0 ; } void add (int p, int k) { for (; p < N*2 ; p += lowbit (p)) tr[p] += k; } int query (int p) { int res = 0 ; for (; p; p -= lowbit (p)) res += tr[p]; return res; } } bit; struct Object { int l, r; bool operator <(const Object& obj) const { if (r == obj.r) return l < obj.l; return r < obj.r; } } obj[N]; int solve () { cin >> n; vector<int > nums; nums.reserve (n*2 ); for (int i = 1 ; i <= n; i++) cin >> obj[i].l >> obj[i].r, nums.eb (obj[i].l), nums.eb (obj[i].r); sort (nums.begin (), nums.end ()); sort (obj+1 , obj+1 +n); int ans = 0 ; bit.init (n*2 ); for (int i = 1 ; i <= n; i++) { int l = lower_bound (nums.begin (), nums.end (), obj[i].l) - nums.begin () + 1 ; ans += bit.query (l); bit.add (l+1 , -1 ); bit.add (1 , 1 ); } return ans; } signed main () { int t; cin >> t; while (t--) cout << solve () << '\n' ; return 0 ; }