P5169:xtq的异或和 – 题解

主要思路

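这题还挺好的。任意一条路径的异或和,都可以看成由一条链和若干个环组成。我们可以把链和环分开来求:从链上某点走到环上某点、再原路走回来的那一段会被经过两次,异或后恰好抵消,所以不造成影响。具体地,先 DFS 求出每个点到根的异或距离 dist,从 u 到 v 的链就是 dist[u] xor dist[v];再把所有非树边形成的环插入线性基,环的贡献就是线性基能表示出的所有值。所以我们可以把链做一个多项式(dist 的桶,因为链有两个端点所以要用两次),环做一个多项式(线性基张成空间中每个值的系数为 1),再做一遍 FWT_XOR 把"链 × 链 × 环"卷起来就完美了。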

代码

// P5169.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity limits and modulus.
// MAX_B = 2^18 is the length of the value domain used by the XOR transforms:
// every XOR value stored into dist_t[] / loop[] must be smaller than MAX_B.
const int MAX_N = 100000 + 200;
const int MAX_E = 300000 + 200;
const int MAX_B = 262144;
const int mod = 998244353;

// Forward-star adjacency list: head[v] is the first edge index of v (-1 = none).
int head[MAX_N];
// Next free edge slot in edges[].
int current;
// Vertex count, edge count, query count as read from input.
int n;
int m;
int q;
// XOR linear basis; bas[i], when non-zero, is the basis vector whose top bit is i.
int bas[30];
// dist[v]: XOR of edge weights on the DFS-tree path from the root to v.
int dist[MAX_N];
// Histogram of dist[] values (later transformed in place).
int dist_t[MAX_B];
// Per-value tally of the cycle space spanned by bas[] (later transformed in place).
int loop[MAX_B];
// DFS visitation marks.
bool vis[MAX_N];

// One directed arc: target vertex, next arc in the source's list, and weight.
struct edge
{
    int to;
    int nxt;
    int weight;
};
// Each undirected edge is stored as two arcs, hence twice MAX_E.
edge edges[MAX_E << 1];

// Modular exponentiation by repeated squaring: returns base^power modulo `mod`
// for a non-negative exponent. Products are formed in 64-bit arithmetic before
// reduction, so they do not overflow.
int quick_pow(int base, int power)
{
    int result = 1;
    for (; power != 0; power >>= 1)
    {
        if (power & 1)
            result = static_cast<int>(1LL * result * base % mod);
        base = static_cast<int>(1LL * base * base % mod);
    }
    return result;
}

// Modular inverse of 2 modulo the odd prime `mod`: 2 * ((mod + 1) / 2) = mod + 1 ≡ 1.
// This is the same value Fermat's little theorem gives (2^(mod-2)), computed at compile time.
const int inv2 = (mod + 1) / 2;

// Appends the directed arc src -> dst with the given weight to the forward-star
// adjacency list. The new arc takes the next free slot `current`, links to src's
// previous first arc, and becomes src's new first arc.
void addpath(int src, int dst, int weight)
{
    const int id = current;
    edges[id].to = dst;
    edges[id].weight = weight;
    edges[id].nxt = head[src];
    head[src] = id;
    current = id + 1;
}

// Inserts x into the XOR linear basis bas[], where bas[i] (if non-zero) is the
// basis vector whose highest set bit is i.
// x is reduced against the existing vectors from the top bit down. The first bit
// i at which x is still set and bas[i] is empty receives the reduced x. If x
// reduces to 0 it already lies in the span, and the basis is left unchanged
// (so insert(0) is a no-op).
// All slots of bas[] are scanned, so every non-negative value below
// 2^(size of bas) is handled. The former bound of bit 20 silently mishandled
// values with higher bits set.
void insert(int x)
{
    const int top_bit = static_cast<int>(sizeof(bas) / sizeof(bas[0])) - 1;
    for (int i = top_bit; i >= 0 && x != 0; i--)
    {
        if (((x >> i) & 1) == 0)
            continue;
        if (bas[i] == 0)
        {
            bas[i] = x;
            return;
        }
        x ^= bas[i]; // clears bit i; only lower bits can remain set
    }
}

// Enumerates the XOR combinations of the non-empty basis slots bas[dep..20]. At
// each level it either skips bas[dep] or (when that slot is non-empty) XORs it
// into the running value pre. After the last level (dep == 21), pre is tallied
// in loop[].
// Started as prep(0, 0), every element of the span is tallied. Assuming the
// stored vectors are linearly independent (distinct top bits), each element is
// tallied exactly once, so loop[] becomes the span's 0/1 indicator.
void prep(int dep, int pre)
{
    if (dep < 21)
    {
        const int nxt_dep = dep + 1;
        prep(nxt_dep, pre);
        if (bas[dep] != 0)
            prep(nxt_dep, pre ^ bas[dep]);
        return;
    }
    ++loop[pre];
}

// Depth-first traversal from u that fills dist[v] (XOR of weights along the DFS
// tree path from u) and marks vis[v] for every reached vertex.
// Every arc that leads to an already-visited vertex closes a cycle. Its XOR,
// dist[a] ^ dist[b] ^ w, is fed to insert(). The arc back to the tree parent
// yields 0, which insert() ignores.
// Implemented with an explicit stack instead of recursion. Recursion depth
// could reach the vertex count on a path-shaped graph and overflow the call
// stack. Arcs are examined in exactly the same order as the recursive version.
// `fa` is unused and kept only for interface compatibility.
void dfs(int u, int fa)
{
    (void)fa;
    // Each frame: (vertex, next arc of its adjacency list still to examine; -1 = done).
    std::vector<std::pair<int, int>> stk;
    vis[u] = true;
    stk.emplace_back(u, head[u]);
    while (!stk.empty())
    {
        const int cur = stk.back().first;
        const int e = stk.back().second;
        if (e == -1)
        {
            stk.pop_back(); // all arcs of cur examined: "return" from this frame
            continue;
        }
        stk.back().second = edges[e].nxt; // advance before a possible push
        const int v = edges[e].to;
        if (!vis[v])
        {
            dist[v] = dist[cur] ^ edges[e].weight;
            vis[v] = true;
            stk.emplace_back(v, head[v]);
        }
        else
            insert(dist[v] ^ dist[cur] ^ edges[e].weight);
    }
}

// In-place Walsh–Hadamard transform for XOR convolution over arr[0 .. MAX_B),
// with all arithmetic modulo `mod`.
// opt == 1 applies the forward transform. opt == -1 applies the inverse: the
// same butterfly with both outputs multiplied by inv2 at every level. Any other
// opt is treated as forward.
void fwt(int *arr, int opt)
{
    const bool inverse = (opt == -1);
    for (int half = 1; half < MAX_B; half <<= 1)
    {
        const int block = half << 1;
        for (int base = 0; base < MAX_B; base += block)
        {
            for (int off = 0; off < half; off++)
            {
                int &lo = arr[base + off];
                int &hi = arr[base + off + half];
                long long sum = (1LL * lo + hi) % mod;
                long long diff = (1LL * lo + mod - hi) % mod;
                if (inverse)
                {
                    sum = sum * inv2 % mod;
                    diff = diff * inv2 % mod;
                }
                lo = static_cast<int>(sum);
                hi = static_cast<int>(diff);
            }
        }
    }
}

// Entry point.
// Input: n m q, then m undirected weighted edges "u v w", then q query values x.
// Output: for each x, the number (modulo 998244353) of ordered vertex pairs
// (u, v), u == v included, with a walk whose edge-weight XOR is x.
//
// Method: a u -> v walk XOR is dist[u] ^ dist[v] ^ c. Here dist[] is the
// DFS-tree prefix XOR from vertex 1, and c lies in the cycle space spanned by
// bas[]. With D = histogram of dist[] and L = 0/1 indicator of that span, the
// answer table is the XOR convolution D * D * L. It is evaluated pointwise
// after a Walsh–Hadamard transform and then inverted.
//
// NOTE(review): dfs starts only from vertex 1, so this assumes the graph is
// connected; unreached vertices keep dist 0 and would be counted at XOR 0.
// It also assumes every edge weight is < MAX_B, so dist[] can index dist_t[].
// Confirm both against the problem constraints.
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d%d", &n, &m, &q);
    for (int i = 1; i <= m; i++)
    {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        addpath(u, v, w); // an undirected edge is stored as two arcs
        addpath(v, u, w);
    }

    dfs(1, 0);  // dist[] prefix XORs and cycle basis bas[]
    prep(0, 0); // loop[] = tally of every element of span(bas)

    for (int i = 1; i <= n; i++)
        dist_t[dist[i]]++;

    fwt(dist_t, 1);
    fwt(loop, 1);
    for (int i = 0; i < MAX_B; i++)
        loop[i] = 1LL * loop[i] * dist_t[i] % mod * dist_t[i] % mod;
    fwt(loop, -1);

    while (q--)
    {
        int x;
        scanf("%d", &x);
        // All walk XORs lie in [0, MAX_B) (weights < MAX_B, MAX_B a power of two).
        // Queries outside that range have no witness pair, and indexing loop[]
        // with them would read out of bounds.
        printf("%d\n", (x >= 0 && x < MAX_B) ? loop[x] : 0);
    }
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *