位运算卷积
回顾 FFT,它要求的其实是:
Ck=i+j=k∑AiBj
将加号换成位运算符就是 FWT 了。
或卷积
首先看较为简单的或,我们这里把二进制数看作集合:
Ck=i∪j=k∑AiBj
将多项式看作向量,我们需要找到一个线性映射 FWT(X) 满足:
FWT(C)k=FWT(A)kFWT(B)k
并且这个线性映射需要做到可逆,这里直接给出式子了:
FWT(A)k=i⊂k∑Ai
因此有:
FWT(A)kFWT(B)k=i⊂k∑Aij⊂k∑Bj=(i∪j)⊂k∑AiBj=x⊂k∑(i∪j)=x∑AiBj=FWT(C)k
我们考虑如何进行 FWT 变换,假设多项式共有有 2n 项,前半部分用 A0 表示,后边部分用 A1 表示,考虑分治:
FWT(A)=merge(FWT(A0),FWT(A0)+FWT(A1))
由于后半部分每个对应位置都相当于前半部分对应位置的二进制位最前面添上一个 1,所以后半部分就可以直接向量相加,补充上缺失的部分。merge 就是把两个向量首尾相接拼上。
因为有这个式子,所以我们进行逆变换的时候,直接这样:
IFWT(A)=merge(IFWT(A0),IFWT(A1)−IFWT(A0))
这个是因为 FWT 是线性变换,于是 FWT(A0)+FWT(A1)=FWT(A0+A1),所以在逆变换的时候可以直接减掉。
和卷积
类似地,定义一下卷积:
Ck=i∩j=k∑AiBj
然后定义一下 FWT 这个变换:
FWT(A)k=i⊃k∑Ai
因此有:
FWT(A)kFWT(B)k=i⊃k∑Aij⊃k∑Bj=(i∩j)⊃k∑AiBj=x⊃k∑i∩j=x∑AiBj=FWT(C)k
然后看如何分治去求,下标 k 如果包含在前半部分的某个下 i 里面,它也一定包含在后半部分的里边;反之就不成立了,所以可以这样写:
FWT(A)IFWT(A)=merge(FWT(A0)+FWT(A1),FWT(A1))=merge(IFWT(A0)−IFWT(A1),IFWT(A1))
异或卷积
处理起来不太一样,主要用到了 popcount 的性质。
popcount(i&k)+popcount(j&k)≡popcount((i⊕j)&k)(mod2)
分每一位去看:
- 若 k 的这一位为 0,那么最终是不产生贡献的。
- 若 k 的这一位为 1,并且 i,j 这一位相同,左边的贡献为 2,右边的贡献为 0,是同余的。
- 若 k 的这一位为 1,并且 i,j 这一位不同,左边的贡献为 1,右边的贡献为 1,是同余的。
看到同余于 2,想到一定会有:
(−1)LHS=(−1)RHS
所以定义线性变换:
FWT(A)k=i∑(−1)popcount(i&k)Ai
因此有:
FWT(A)kFWT(B)k=i,j∑(−1)popcount(i&k)Ai(−1)popcount(j&k)Bj=i,j∑(−1)popcount((i⊕j)&k)AiBj=FWT(C)k
这和前面两个不同的地方在于每一项都是对整个向量所有值进行加权求和,所以最后根据加法交换律就不用在乎这一项到底是啥,只要前面的系数对了就行。
同样可以进行分治去求,由于前半部分的 k 的最高位为 0,所以 i 是没有贡献到的,后半部分贡献到了:
FWT(A)=merge(FWT(A0)+FWT(A1),FWT(A0)−FWT(A1))IFWT(A)=merge(2IFWT(A0)+IFWT(A1),2IFWT(A0)−IFWT(A1))
接下来就是代码实现了。
代码
来自 P4717 的模板题,如果写个自动取模的 mint
可能会好看一点,我这里就不写了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
| #include <bits/stdc++.h> using namespace std;
const int N = 18, INV2 = 499122177, P = 998244353; int n, a[1<<N], b[1<<N], A[1<<N], B[1<<N], C[1<<N];
void OR(int A[], int inv) { for (int len = 2, half = 1; len <= 1 << n; len <<= 1, half <<= 1) { for (int i = 0; i < 1 << n; i += len) { for (int j = 0; j < half; j++) { A[i+j+half] = (A[i+j+half] + A[i+j] * inv) % P; if (A[i+j+half] < 0) A[i+j+half] += P; } } } }
void AND(int A[], int inv) { for (int len = 2, half = 1; len <= 1 << n; len <<= 1, half <<= 1) { for (int i = 0; i < 1 << n; i += len) { for (int j = 0; j < half; j++) { A[i+j] = (A[i+j] + A[i+j+half] * inv) % P; if (A[i+j] < 0) A[i+j] += P; } } } }
void XOR(int A[]) { for (int len = 2, half = 1; len <= 1 << n; len <<= 1, half <<= 1) for (int i = 0; i < 1 << n; i += len) for (int j = 0; j < half; j++) { int A0 = A[i+j], A1 = A[i+j+half]; A[i+j] = (A0 + A1) % P; A[i+j+half] = (A0 - A1) % P; if (A[i+j+half] < 0) A[i+j+half] += P; } }
void IXOR(int A[]) { for (int len = 2, half = 1; len <= 1 << n; len <<= 1, half <<= 1) for (int i = 0; i < 1 << n; i += len) for (int j = 0; j < half; j++) { int A0 = A[i+j], A1 = A[i+j+half]; A[i+j] = 1LL * (A0 + A1) * INV2 % P; A[i+j+half] = 1LL * (A0 - A1) * INV2 % P; if (A[i+j+half] < 0) A[i+j+half] += P; } }
void cpy() { for (int i = 0; i < 1 << n; i++) A[i] = a[i]; for (int i = 0; i < 1 << n; i++) B[i] = b[i]; }
void mul() { for (int i = 0; i < 1 << n; i++) C[i] = 1LL * A[i] * B[i] % P; }
void out() { for (int i = 0; i < 1 << n; i++) printf("%d ", C[i]); puts(""); }
int main() { scanf("%d", &n); for (int i = 0; i < 1 << n; i++) scanf("%d", &a[i]); for (int i = 0; i < 1 << n; i++) scanf("%d", &b[i]);
cpy(); OR(A, 1), OR(B, 1); mul(); OR(C, -1); out();
cpy(); AND(A, 1), AND(B, 1); mul(); AND(C, -1); out();
cpy(); XOR(A), XOR(B); mul(); IXOR(C); out();
return 0; }
|