快速入门
记公式:
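In every formula below, opt = 1 for the forward transform and opt = -1 for the inverse transform. OR transform: arr[k + step] += opt * arr[k]. AND transform: arr[k] += opt * arr[k + step]. XOR transform: let A = arr[k] and B = arr[k + step], then set arr[k] = A + B and arr[k + step] = A - B; for the inverse XOR transform, additionally halve both entries (under a modulus, multiply by inv2, the modular inverse of 2): arr[k] /= 2, arr[k + step] /= 2.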
// P4717.cpp
// Fast Walsh-Hadamard Transform (FWT): OR, AND and XOR convolutions of two
// sequences of length 2^n, with all arithmetic done modulo 998244353.
#include <cstdio>
#include <cstring>

using namespace std;

// Largest exponent the fixed-size arrays below can hold (length 2^MAX_BITS).
const int MAX_BITS = 17;
const int MAX_N = (1 << MAX_BITS) + 200, mod = 998244353;

// n holds the exponent while reading input and the sequence length 2^n afterwards.
// Ai/Bi keep the pristine inputs; A/B/C are scratch buffers for each convolution.
int n, Ai[MAX_N], Bi[MAX_N], C[MAX_N], A[MAX_N], B[MAX_N];

// Returns bas^tim modulo mod using binary exponentiation.
int quick_pow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

// Modular inverse of 2 (Fermat's little theorem), used by the inverse XOR transform.
const int inv2 = quick_pow(2, mod - 2);

// In-place OR transform of arr[0..n). opt = 1 is the forward transform,
// opt = -1 the inverse. Entries are expected to be in [0, mod).
void fwt_or(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
                // "+ mod" keeps the value non-negative when opt == -1.
                arr[k + step] = (1LL * arr[k + step] + mod + 1LL * arr[k] * opt) % mod;
}

// In-place AND transform of arr[0..n). opt = 1 is the forward transform,
// opt = -1 the inverse. Entries are expected to be in [0, mod).
void fwt_and(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
                // "+ mod" keeps the value non-negative when opt == -1.
                arr[k] = (1LL * arr[k] + mod + 1LL * arr[k + step] * opt) % mod;
}

// In-place XOR (Walsh-Hadamard) transform of arr[0..n). opt = -1 selects the
// inverse, which additionally multiplies every butterfly output by inv2.
void fwt_xor(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
            {
                const int lo = arr[k], hi = arr[k + step];
                arr[k] = (1LL * lo + hi) % mod;
                arr[k + step] = (1LL * lo + mod - hi) % mod;
                if (opt == -1)
                {
                    arr[k] = 1LL * arr[k] * inv2 % mod;
                    arr[k + step] = 1LL * arr[k + step] * inv2 % mod;
                }
            }
}

// Convolves Ai and Bi under the given transform (forward on both inputs,
// pointwise product, inverse on the product) and prints the result on one line.
static void convolve_and_print(void (*transform)(int *, int))
{
    memcpy(A, Ai, sizeof(A));
    memcpy(B, Bi, sizeof(B));
    transform(A, 1);
    transform(B, 1);
    for (int i = 0; i < n; i++)
        C[i] = 1LL * A[i] * B[i] % mod;
    transform(C, -1);
    for (int i = 0; i < n; i++)
        printf("%d ", C[i]);
    puts("");
}

int main()
{
    // Reject missing input and exponents the fixed-size arrays cannot hold.
    if (scanf("%d", &n) != 1 || n < 0 || n > MAX_BITS)
        return 1;
    n = 1 << n;
    for (int i = 0; i < n; i++)
        if (scanf("%d", &Ai[i]) != 1)
            return 1;
    for (int i = 0; i < n; i++)
        if (scanf("%d", &Bi[i]) != 1)
            return 1;

    convolve_and_print(fwt_or);
    convolve_and_print(fwt_and);
    convolve_and_print(fwt_xor);
    return 0;
}
应用
CF662C Binary Table
发现\(n\)非常小,所以我们可以把每一列压成一个二进制数,然后统计每种列状态出现的次数。如果我们枚举行的翻转情况\(S\)(第一位为\(1\)代表第一行翻转),让操作生效,再扫一遍对每一列进行讨论,就可以得到该\(S\)下的局部最优解,最后对所有\(S\)取最小值得到全局最优解。这样直接做时间复杂度很高,但是可以看到潜在的优化空间。我们可以用数学的方式把答案写出来:
\[ \min_{0 \le S < 2^n} \sum_{x \text{ xor } y = S} cnt[x] \cdot best[y] \]
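其中,\(cnt[x]\)表示状态为\(x\)的列的个数;行翻转\(S\)把原来为\(x\)的列变成\(y = x \text{ xor } S\),\(y\)的贡献\(best[y]\)自然是\(\min\{ \text{0的个数}, \text{1的个数} \}\)。内层的和正是\(cnt\)与\(best\)的异或卷积,可以用 FWT 一次求出所有\(S\)的值。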
// CF662C.cpp
// Reads an n x m grid of binary digits, packs each column into a bitmask, and
// prints the minimum entry of the XOR convolution of
//   ai[x] = number of columns whose mask is x
//   bi[y] = min(popcount(y), n - popcount(y))
// computed with an integer Walsh-Hadamard transform.
#include <algorithm>
#include <cstdio>

using namespace std;

typedef long long ll;

const int MAX_N = 22, MAX_M = 1e5 + 200;

int n, m;
// stats[j] is the bitmask of column j (bit i-1 comes from row i);
// poly_siz is the transform length 2^n.
ll ai[1 << MAX_N], bi[1 << MAX_N], stats[MAX_M], poly_siz;

// In-place XOR (Walsh-Hadamard) transform of arr[0..poly_siz) over the integers.
// opt = -1 selects the inverse, which halves every butterfly output.
void fwt_xor(ll *arr, int opt)
{
    for (int step = 1; step < poly_siz; step <<= 1)
        for (int j = 0; j < poly_siz; j += (step << 1))
            for (int k = j; k < j + step; k++)
            {
                const ll lo = arr[k], hi = arr[k + step];
                arr[k] = lo + hi;
                arr[k + step] = lo - hi;
                // When inverting a genuine transform both values are even, so the
                // division is exact. "/ 2" is used instead of ">> 1" because
                // right-shifting a negative value is implementation-defined
                // before C++20.
                if (opt == -1)
                    arr[k] /= 2, arr[k + step] /= 2;
            }
}

int main()
{
    // Reject missing input and sizes the fixed-size arrays cannot hold.
    if (scanf("%d%d", &n, &m) != 2 || n < 0 || n > MAX_N || m < 0 || m >= MAX_M)
        return 1;

    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++)
        {
            int val = 0;
            // "%1d" consumes exactly one digit, so rows may be given without separators.
            if (scanf("%1d", &val) != 1)
                return 1;
            stats[j] |= (static_cast<ll>(val) << (i - 1));
        }

    for (int i = 1; i <= m; i++)
        ai[stats[i]]++;
    for (int i = 0; i < (1 << n); i++)
        bi[i] = min(__builtin_popcount(i), n - __builtin_popcount(i));

    poly_siz = 1LL << n;
    fwt_xor(ai, 1);
    fwt_xor(bi, 1);
    for (int i = 0; i < poly_siz; i++)
        ai[i] *= bi[i];
    fwt_xor(ai, -1);

    // poly_siz >= 1, so the range is never empty.
    const ll ans = *min_element(ai, ai + poly_siz);
    printf("%lld\n", ans);
    return 0;
}
LOJ 152 子集卷积
子集卷积就是:
\[ h_U = \sum_{S \subseteq U} f_S \, g_{U \setminus S} \]
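现在题目要求我们对于不同的\(U\)进行求解。直接做 OR 卷积只能保证\(S \cup T = U\),无法保证\(S \cap T = \varnothing\);而\(S \cap T = \varnothing\)等价于\(|S| + |T| = |S \cup T|\)。所以我们可以考虑给\(f\)和\(g\)按位为\(1\)的个数进行分类,变成\(\{f_b\}\)和\(\{g_b\}\),再按位个数对这\(n + 1\)个多项式(位数为\(0\)到\(n\))分别做 FWT_OR,然后就可以得到他们的子集展开。我们只需要枚举全集大小\(|U|\)和一个子集的大小\(|S|\),再把展开后的结果逐点相乘累加到\(|U|\)的答案中,最后对每个答案多项式做逆变换即可。输出答案时要注意从位数等于\(|U|\)的那个多项式中提取。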
#pragma GCC optimize(2)
// LOJ152.cpp
// Subset convolution modulo 1e9 + 9: both inputs are split by popcount, every
// layer is OR-transformed, layers are multiplied pointwise so that popcounts
// add up, and each result layer is transformed back.
#include <cctype>
#include <cstdio>

using namespace std;

// The layered tables have MAX_BITS + 1 rows (popcounts 0..MAX_BITS).
const int MAX_BITS = 21;
const int MAX_N = (1 << 22) + 2000, mod = 1e9 + 9;

// fi[c][i] / gi[c][i] hold the inputs restricted to indices with popcount c;
// cnt[i] is the popcount of i; ans[c] accumulates the products of total popcount c.
int n, fi[MAX_BITS + 1][MAX_N], gi[MAX_BITS + 1][MAX_N], poly_siz, cnt[MAX_N], ans[MAX_BITS + 1][MAX_N];

// Returns the next byte of stdin through a large static buffer, or EOF when the
// input is exhausted. The result is an int so that EOF stays distinguishable
// from every byte value regardless of the signedness of char.
inline int nc()
{
    static char buf[10000000], *p1 = buf, *p2 = buf;
    if (p1 == p2)
    {
        p2 = (p1 = buf) + fread(buf, 1, sizeof(buf), stdin);
        if (p1 == p2)
            return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// Parses the next (optionally negative) decimal integer from stdin.
// Returns 0 if the input ends before any digit is found.
int read()
{
    int x = 0, f = 1;
    int ch = nc();
    // Skip non-digits; stop at EOF so truncated input cannot loop forever.
    while (ch != EOF && !isdigit(ch))
    {
        if (ch == '-')
            f = -1;
        ch = nc();
    }
    // isdigit(EOF) is false, so this loop also terminates at end of input.
    while (isdigit(ch))
        x = (x << 3) + (x << 1) + ch - '0', ch = nc();
    return x * f;
}

// In-place OR transform of arr[0..poly_siz) modulo mod. opt = 1 is the forward
// transform, opt = -1 the inverse. Entries are expected to be in [0, mod).
inline void fwt_or(int *arr, int opt)
{
    for (int step = 1; step < poly_siz; step <<= 1)
        for (int j = 0; j < poly_siz; j += (step << 1))
            for (int k = j; k < j + step; k++)
                // "+ mod" keeps the value non-negative when opt == -1.
                arr[k + step] = (1LL * arr[k + step] + mod + 1LL * opt * arr[k]) % mod;
}

int main()
{
    n = read();
    // The layered tables only have rows for popcounts 0..MAX_BITS.
    if (n < 0 || n > MAX_BITS)
        return 1;
    poly_siz = (1 << n);

    // popcount(i) = popcount(i / 2) + lowest bit of i.
    for (int i = 0; i < poly_siz; i++)
        cnt[i] = cnt[i >> 1] + (i & 1);
    for (int i = 0; i < poly_siz; i++)
        fi[cnt[i]][i] += read();
    for (int i = 0; i < poly_siz; i++)
        gi[cnt[i]][i] += read();

    for (int i = 0; i <= n; i++)
        fwt_or(fi[i], 1), fwt_or(gi[i], 1);

    // Layer `total` of the answer sums the products of layers whose popcounts add to it.
    for (int total = 0; total <= n; total++)
        for (int part = 0; part <= total; part++)
            for (int k = 0; k < poly_siz; k++)
                ans[total][k] = (1LL * ans[total][k] + 1LL * fi[total - part][k] * gi[part][k] % mod) % mod;

    for (int i = 0; i <= n; i++)
        fwt_or(ans[i], -1);

    // Only the layer whose index equals popcount(i) carries the true value for i.
    for (int i = 0; i < poly_siz; i++)
        printf("%d ", ans[cnt[i]][i]);
    puts("");
    return 0;
}