快速沃尔什变换|FWT

快速入门

记公式:

Or operation:

    arr[k + step] += opt * arr[k]

And operation:

    arr[k] += opt * arr[k + step]

Xor operation:

    A = arr[k], B = arr[k + step]

    arr[k] = A + B, arr[k + step] = A - B

    with inverse-operation, the inv2 is needed:

    arr[k] /= 2, arr[k + step] /= 2;

// P4717.cpp
// NTT;
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = (1 << 17) + 200, mod = 998244353;

int n, Ai[MAX_N], Bi[MAX_N], C[MAX_N], A[MAX_N], B[MAX_N];

int quick_pow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

const int inv2 = quick_pow(2, mod - 2);

void fwt_or(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
                arr[k + step] = (1LL * arr[k + step] + mod + 1LL * arr[k] * opt) % mod;
}

void fwt_and(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
                arr[k] = (1LL * arr[k] + mod + 1LL * arr[k + step] * opt) % mod;
}

void fwt_xor(int *arr, int opt)
{
    for (int step = 1; step < n; step <<= 1)
        for (int j = 0; j < n; j += (step << 1))
            for (int k = j; k < j + step; k++)
            {
                int A = arr[k], B = arr[k + step];
                arr[k] = (1LL * A + B) % mod, arr[k + step] = (1LL * A + mod - B) % mod;
                if (opt == -1)
                    arr[k] = 1LL * arr[k] * inv2 % mod, arr[k + step] = 1LL * arr[k + step] * inv2 % mod;
            }
}

int main()
{
    scanf("%d", &n), n = 1 << n;
    for (int i = 0; i < n; i++)
        scanf("%d", &Ai[i]);
    for (int i = 0; i < n; i++)
        scanf("%d", &Bi[i]);

    memcpy(A, Ai, sizeof(A)), memcpy(B, Bi, sizeof(B));
    fwt_or(A, 1), fwt_or(B, 1);
    for (int i = 0; i < n; i++)
        C[i] = 1LL * A[i] * B[i] % mod;
    fwt_or(C, -1);
    for (int i = 0; i < n; i++)
        printf("%d ", C[i]);
    puts("");

    memcpy(A, Ai, sizeof(A)), memcpy(B, Bi, sizeof(B));
    fwt_and(A, 1), fwt_and(B, 1);
    for (int i = 0; i < n; i++)
        C[i] = 1LL * A[i] * B[i] % mod;
    fwt_and(C, -1);
    for (int i = 0; i < n; i++)
        printf("%d ", C[i]);
    puts("");

    memcpy(A, Ai, sizeof(A)), memcpy(B, Bi, sizeof(B));
    fwt_xor(A, 1), fwt_xor(B, 1);
    for (int i = 0; i < n; i++)
        C[i] = 1LL * A[i] * B[i] % mod;
    fwt_xor(C, -1);
    for (int i = 0; i < n; i++)
        printf("%d ", C[i]);
    puts("");
    return 0;
}

应用

CF662C Binary Table

发现\(n\)非常的小,所以我们可以考虑把每一列压起来,然后统计相同列数的个数。如果我们枚举行的反转情况\(S\)(第一位为\(1\)代表第一行反转),然后把操作生效,再扫一遍对每一列进行讨论就可以得到局部最优解,再合成到全剧最优解。这样时间复杂度很高,但是我们可以看到潜在的优化空间。我们可以考虑用数学的方式归纳出来:

\[ \sum_{S = 0}^{2^n – 1} \sum_{x \text{xor} y = S} cnt[x] \cdot best[y] \]

其中,我们把原来为\(x\)的列变成\(y\),\(y\)的贡献自然是\(\min\{ \text{0的个数}, \text{1的个数} \}\)。

// CF662C.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 22, MAX_M = 1e5 + 200;

int n, m;
ll ai[1 << MAX_N], bi[1 << MAX_N], stats[MAX_M], poly_siz;

void fwt_xor(ll *arr, int opt)
{
    for (int step = 1; step < poly_siz; step <<= 1)
        for (int j = 0; j < poly_siz; j += (step << 1))
            for (int k = j; k < j + step; k++)
            {
                ll A = arr[k], B = arr[k + step];
                arr[k] = A + B, arr[k + step] = A - B;
                if (opt == -1)
                    arr[k] >>= 1, arr[k + step] >>= 1;
            }
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        for (int j = 1, val; j <= m; j++)
            scanf("%1d", &val), stats[j] |= (val << (i - 1));
    for (int i = 1; i <= m; i++)
        ai[stats[i]]++;
    for (int i = 0; i < (1 << n); i++)
        bi[i] = min(__builtin_popcount(i), n - __builtin_popcount(i));
    poly_siz = 1LL << n;
    fwt_xor(ai, 1), fwt_xor(bi, 1);
    for (int i = 0; i < poly_siz; i++)
        ai[i] *= bi[i];
    fwt_xor(ai, -1);
    ll ans = 0x3f3f3f3f3f3f3f3f;
    for (int i = 0; i < poly_siz; i++)
        ans = min(ans, ai[i]);
    printf("%lld\n", ans);
    return 0;
}

LOJ 152 子集卷积

子集卷积就是:

\[ \sum_{S \subset U} f_S g_{U – S} \]

现在题目要求我们对于不同的\(U\)进行求解。我们可以考虑给\(f\)和\(g\)按位为\(1\)的个数进行分类,变成\(\{f_b\}\)和\(\{g_b\}\),再按位个数对\(n\)个多项式做 FWT_OR,然后就可以得到他们的子集展开。我们只需要枚举全集大小\(|U|\)和一个子集的大小\(|S|\),再把展开后的进行相乘存入\(|U|\)的答案中即可。输出答案要注意从相应个数的多项式中提取。

#pragma GCC optimize(2)
// LOJ152.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = (1 << 22) + 2000, mod = 1e9 + 9;

int n, fi[22][MAX_N], gi[22][MAX_N], poly_siz, cnt[MAX_N], ans[22][MAX_N], gans[MAX_N];

inline char nc()
{
    static char buf[10000000], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 10000000, stdin), p1 == p2) ? EOF : *p1++;
}

int read()
{
    int x = 0, f = 1;
    char ch = nc();
    while (!isdigit(ch))
    {
        if (ch == '-')
            f = -1;
        ch = nc();
    }
    while (isdigit(ch))
        x = (x << 3) + (x << 1) + ch - '0', ch = nc();
    return x * f;
}

inline void fwt_or(int *arr, int opt)
{
    for (int step = 1; step < poly_siz; step <<= 1)
        for (int j = 0; j < poly_siz; j += (step << 1))
            for (int k = j; k < j + step; k++)
                arr[k + step] = (1LL * arr[k + step] + mod + 1LL * opt * arr[k]) % mod;
}

int main()
{
    n = read(), poly_siz = (1 << n);
    for (int i = 0; i <= poly_siz; i++)
        cnt[i] = cnt[i >> 1] + (i & 1);
    for (int i = 0; i < poly_siz; i++)
        fi[cnt[i]][i] += read();
    for (int i = 0; i < poly_siz; i++)
        gi[cnt[i]][i] += read();
    for (int i = 0; i <= n; i++)
        fwt_or(fi[i], 1), fwt_or(gi[i], 1);
    for (int i = 0; i <= n; i++)
        for (int j = 0; j <= i; j++)
            for (int k = 0; k < poly_siz; k++)
                ans[i][k] = (1LL * ans[i][k] + 1LL * fi[i - j][k] * gi[j][k] % mod) % mod;
    for (int i = 0; i <= n; i++)
        fwt_or(ans[i], -1);
    for (int i = 0; i < poly_siz; i++)
        printf("%d ", ans[cnt[i]][i]);
    puts("");
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *