题目

CF2144E1CF2144E2

思路1

我们需要把从第一个元素开始的最长上升序列求出来,设这样的序列是 S1,S2,,SkS_1,S_2,\dots,S_k

为了方便,我们可以再所有元素之前插入一个 00,使得有效的子序列都以第一个元素开头。

由于左边的上升序列是不可能看到右边的元素的,所以我们不妨把左右分开考虑。并且,右边的答案可以通过 reverse 后按照同样的方式求出左边的答案之后再次 reverse 得到,因此我们可以只考虑左边的做法。

dpidp_i 是以 aia_i 结尾的答案,preipre_{i} 是序列 {Sn}\{S_n\}aia_i 的前一项。由于我们插入了 00 这个元素在开头,所以任意一个 i2i\ge 2aia_ipreipre_i 都是一定存在的。

那么,在 preiaipre_i\sim a_i 这些元素中,我们可以自由选择 akpreia_k\le pre_i 的元素是否加入这个子序列,所以递推式会是:

dpi={j<i,aj=prei2k=j+1i1[akprei]dpjai{Sn}0ai∉{Sn}dp_i=\begin{cases} \sum_{j<i,a_j=pre_i}2^{\sum_{k=j+1}^{i-1}[a_k\le pre_i]}dp_j\quad a_i\in \{S_n\}\\ 0\quad a_i\not\in\{S_n\} \end{cases}

其中 [p(k)][p(k)] 是艾佛森括号。设左边做出来的是 dp1idp_{1i},右边做出来的是 dp2idp_{2i}

那么还有一个问题,设 m=Skm=S_k(即上升序列的最后一项),{an}\{a_n\} 这个序列可能有多个 ai=ma_i=m,那么我们该如何计数?

如果选定的子序列只有一个 ai=ma_i=m,贡献显然是 dp1i×dp2idp_{1i}\times dp_{2i};如果存在多个,设 ai=aj=ma_i=a_j=m 是最左边和最右边的那两个,那么 i,ji,j 中间的数字显然可以随意选择,因此贡献是 dp1i×dp2j×2ji1dp_{1i}\times dp_{2j}\times 2^{j-i-1}

实现1

这是 Easy Version 的代码,注意上文中的 aia_i 对应离散化后的数组是 bib_i

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define int long long
#define lowbit(x) ((x)&(-x))
const int P = 998244353, N = 5010;

int qpow(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = res * a % P;
a = a * a % P;
k >>= 1;
}
return res;
}

int a[N], b[N], dp1[N], dp2[N], nxt[N];

// 离散化后 b[i] 的长度
int p;

void dodp(int* dp) {
// 记录上升序列的值, mp1 是把元素映射到编号, mp2 是把编号映射到元素
map<int, int> mp1, mp2;
for (int i = 1; i <= p; i++) {
nxt[i] = dp[i] = 0; // clear
for (int j = i+1; j <= p; j++) {
if (b[j] > b[i]) {
nxt[i] = j;
break;
}
}
}

int cnt = 0;
for (int i = 1; i; i = nxt[i]) {
cnt++;
mp1[b[i]] = cnt;
mp2[cnt] = b[i];
}

dp[1] = 1;
for (int i = 2; i <= p-1; i++) {
if (mp1.count(b[i]) == 0) continue;
int pre = mp2[mp1[b[i]]-1];
int fac = 1;
for (int j = i-1; j >= 1; j--) {
if (b[j] == pre) dp[i] = (dp[i] + fac * dp[j]) % P;
if (b[j] <= pre) fac = fac * 2 % P;
}
}
}

void solve() {
int n, m;
cin >> n;
vector<int> nums;
for (int i = 1; i <= n; i++) {
cin >> a[i];
nums.pb(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
m = nums.size();

p = 0;
b[++p] = 0;
for (int i = 1; i <= n; i++) {
int to = lower_bound(nums.begin(), nums.end(), a[i]) - nums.begin() + 1;
b[++p] = to;
}
b[++p] = 0;

dodp(dp1);
reverse(b+1, b+1+p);
dodp(dp2);
reverse(dp2+1, dp2+1+p);
reverse(b+1, b+1+p);

int ans = 0;
for (int i = 1; i <= p; i++) {
for (int j = i; j <= p; j++) {
if (b[i] != m || b[j] != m) continue;

ans = (ans + dp1[i] * dp2[j] % P * qpow(2, max(j-i-1, 0ll))) % P;
}
}

cout << ans << '\n';
}

signed main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int T;
cin >> T;
while (T--) solve();
return 0;
}

思路2

接着上文的 dpdp 方法,我们发现导致复杂度为 O(n2)O(n^2) 的原因主要是:

  1. 求上升序列时用了二重循环。
  2. 状态转移时用了二重循环。

对于第一个问题,有很多种方法可以优化,我这里选择开一个线段树。

实际上树状数组应当也可以。但是我的实现中需要求的是对值域的后缀最小值,而树状数组只支持前缀的操作,所以你需要对下标进行一次变换。

第二个问题需要具体分析这个转移方程。

首先我们发现,转移的时候只需要 aj=preia_j=pre_i 的这些 jj 点,所以我们可以定义序列 FvF_v

Fv=j<i,aj=v2k=j+1i1[akv]dpjF_v=\sum_{j<i,a_j=v}2^{\sum_{k=j+1}^{i-1}[a_k\le v]}dp_j

这样的话,求 dpidp_i 可以直接查表:

dpi={Fpreiai{Sn}0ai∉{Sn}dp_i=\begin{cases} F_{pre_i}\quad a_i\in \{S_n\}\\ 0\quad a_i\not\in\{S_n\} \end{cases}

主要问题是如何更新序列 FvF_v。求完一个 dpidp_i 之后,首先我们需要关注 2k=j+1i1[akv]2^{\sum_{k=j+1}^{i-1}[a_k\le v]} 这一项。具体地说,是:

Fv2Fv where aivF_v\gets 2F_v \text{ where }a_i\le v

其次,需要把这个新的 dpidp_i 加到 FaiF_{a_i} 去,也就是:

FaiFai+dpiF_{a_i}\gets F_{a_i}+dp_i

可以看出,这是一个经典的线段树区间加乘可以维护的问题。

对于答案中 j<ij<i 的部分,我们可以写出它的式子:

ai=mj<i,aj=mdp2idp1j2ij1=ai=mdp2i2i(j<i,aj=mdp1j2j1)\sum_{a_i=m} \sum_{j<i,a_j=m} dp_{2i}dp_{1j}2^{i-j-1}=\sum_{a_i=m}dp_{2i}2^i\left(\sum_{j<i,a_j=m}dp_{1j}2^{-j-1}\right)

再加上 j=ij=i 的部分就是最终答案。

实现2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define int long long
#define lowbit(x) ((x)&(-x))
const int P = 998244353, N = 300010, INF = 1e18;

int qpow(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = res * a % P;
a = a * a % P;
k >>= 1;
}
return res;
}

int a[N], b[N], dp1[N], dp2[N], nxt[N];

// length of b and dp, number of values
int p, m;
#define ls (u<<1)
#define rs (u<<1|1)
#define Mid ((L+R) >> 1)
struct SGT {
int mn[N<<2];

void build(int u, int L, int R) {
if (L == R) return mn[u] = INF, void();
build(ls, L, Mid);
build(rs, Mid+1, R);
pushup(u);
}

void pushup(int u) {
mn[u] = min(mn[ls], mn[rs]);
}

void modify(int u, int p, int k, int L, int R) {
if (L == R) return mn[u] = k, void();
if (p <= Mid) modify(ls, p, k, L, Mid);
else modify(rs, p, k, Mid+1, R);
pushup(u);
}

int query(int u, int l, int r, int L, int R) {
if (l <= L && R <= r) return mn[u];
int res = INF;
if (l <= Mid) res = query(ls, l, r, L, Mid);
if (Mid+1 <= r) res = min(res, query(rs, l, r, Mid+1, R));
return res;
}
} t;

struct SGT2 {
int add[N<<2], mul[N<<2];

void build(int u, int L, int R) {
add[u] = 0;
mul[u] = 1;
if (L == R) {
return;
}
build(ls, L, Mid);
build(rs, Mid+1, R);
}

void spread(int u, int m, int a) {
mul[u] = mul[u] * m % P;
add[u] = (m * add[u] + a) % P;
}

void pushdown(int u) {
spread(ls, mul[u], add[u]);
spread(rs, mul[u], add[u]);
mul[u] = 1;
add[u] = 0;
}

int query(int u, int p, int L, int R) {
if (L == R) return add[u];
pushdown(u);
if (p <= Mid) return query(ls, p, L, Mid);
else return query(rs, p, Mid+1, R);
}

void modify(int u, int l, int r, int m, int a, int L, int R) {
if (l <= L && R <= r) return spread(u, m, a);
pushdown(u);
if (l <= Mid) modify(ls, l, r, m, a, L, Mid);
if (Mid+1 <= r) modify(rs, l, r, m, a, Mid+1, R);
}
} F;

void dodp(int dp[]) {
map<int, int> mp1, mp2;

t.build(1, 0, m);
for (int i = p; i >= 1; i--) {
dp[i] = 0;
nxt[i] = t.query(1, b[i]+1, m, 0, m);
t.modify(1, b[i], i, 0, m);
}

int cnt = 0;
for (int i = 1; i != INF; i = nxt[i]) {
cnt++;
mp1[b[i]] = cnt;
mp2[cnt] = b[i];
}

dp[1] = 1;
F.build(1, 0, m);
F.modify(1, 0, 0, 1, 1, 0, m);
for (int i = 2; i <= p-1; i++) {
if (mp1.count(b[i]) == 0) {
F.modify(1, b[i], m, 2, 0, 0, m);
continue;
}
int pre = mp2[mp1[b[i]]-1];
dp[i] = F.query(1, pre, 0, m);
F.modify(1, b[i], m, 2, 0, 0, m);
F.modify(1, b[i], b[i], 1, dp[i], 0, m);
}
}

void solve() {
int n;
cin >> n;
vector<int> nums;
for (int i = 1; i <= n; i++) {
cin >> a[i];
nums.pb(a[i]);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
m = nums.size();

p = 0;
b[++p] = 0;
for (int i = 1; i <= n; i++) {
int to = lower_bound(nums.begin(), nums.end(), a[i]) - nums.begin() + 1;
b[++p] = to;
}
b[++p] = 0;

dodp(dp1);
reverse(b+1, b+1+p);
dodp(dp2);
reverse(dp2+1, dp2+1+p);
reverse(b+1, b+1+p);

int ans = 0;
int now = 0;
for (int i = 2; i <= p-1; i++) {
if (b[i] != m) continue;
ans = (ans + dp1[i] * dp2[i]) % P;
ans = (ans + dp2[i] * now % P * qpow(2, i)) % P;
now = (now + dp1[i] * qpow(2, P-2-i)) % P;
}

cout << ans << '\n';
}

signed main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int T;
cin >> T;
while (T--) solve();
return 0;
}