Loading [MathJax]/extensions/tex2jax.js

BZOJ 3456:城市规划题解

题面

注意:这道题在 BZOJ 上是权限题,我是在 JZOJ 上做的。

刚刚解决完电力网络的问题,阿狸又被领导的任务给难住了。

刚才说过, 阿狸的国家有\(n\)个城市,现在国家需要在某些城市对之间建立一些贸易路线, 使得整个国家的任意两个城市都直接或间接的连通。

为了省钱,每两个城市之间最多只能有一条直接的贸易路径。对于两个建立路线的方案, 如果存在一个城市对,在两个方案中是否建立路线不一样,那么这两个方案就是不同的,否则就是相同的。现在你需要求出一共有多少不同的方案。

好了,这就是困扰阿狸的问题。换句话说,你需要求出\(n\)个点的简单(无重边无自环)无向连通图数目。

由于这个数字可能非常大,你只需要输出方案数 mod 1004535809(479 * 2 ^21 + 1) 即可。

主要思路

其实这道题的式子很好推,难就难在要多项式求逆之类的:

\[ f(n) = g(n) – \sum_{i = 1}^{n – 1} f(i) {n – 1 \choose i – 1} g(n – i) \]

其中,\(f(n)\)是大小为\(n\)的无向连通图的个数,而\(g(n)\)为无向图的个数(也就是说,存在不联通的情况),\(g(n) = 2^{\frac{n(n-1)}{2}}\)。稍稍变换一下:

\[ \sum_{i = 1}^n f(i) {n – 1 \choose i – 1} g(n – i) = g(n) \]

拆开组合数:

\[ \sum_{i = 1}^n \frac{f(i)g(n-i)(n-1)!}{(i-1)!(n-i)!} = g(n) \\ \sum_{i = 1}^n \frac{f(i)g(n-i)}{(i-1)!(n-i)!} = \frac{g(n)}{(n-1)!} \]

拆开一些部分:

\[ \sum_{i = 1}^n \frac{f(i)}{(i-1)!} \frac{g(n-i)}{(n-i)!} = \frac{g(n)}{(n-1)!} \]

发现是卷积的形式。考虑设置生成函数:

\[ G(x) = \sum_{i = 0}^{\infty}\frac{g(i)}{i!} x^n \\ C(x) = \sum_{i = 0}^{\infty} \frac{g(i)}{(i-1)!} x^n \\ F(x) = \sum_{i = 0}^{\infty} \frac{f(i)}{(i-1)!} x^n \]

根据题意可以写成:

\[ F(x)G(x) \equiv C(x) \ (mod \ x^n) \]

多项式在解题中的好处就是:你可以忽略掉多项式中的变量,只关注其系数,然后用多项式的角度对这些系数的关系进行描述,然后用多项式的算法来解题。

\[ F(x) \equiv C(x)G^{-1}(x) \ (mod \ x^n) \]

然后\(f(n)\)最终就等于右边多项式的第\(n\)个系数再乘上一个\((n-1)!\)就行了。

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// C.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int mod = 1004535809, G = 3, MAX_N = 1 << 18;
int n, rev[MAX_N], Gi;
ll tmp[MAX_N], level[MAX_N], level_inv[MAX_N], gx[MAX_N], CX[MAX_N], GX[MAX_N], invG[MAX_N];
ll quick_pow(ll bas, ll tim)
{
ll ans = 1;
bas %= mod;
while (tim)
{
if (tim & 1)
ans = ans * bas % mod;
bas = bas * bas % mod;
tim >>= 1;
}
return ans;
}
inline void ntt(ll *arr, int limit, int dft)
{
for (int i = 0; i < limit; i++)
if (i < rev[i])
swap(arr[i], arr[rev[i]]);
for (int step = 1; step < limit; step <<= 1)
{
ll omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
for (int j = 0; j < limit; j += (step << 1))
{
ll omega_nk = 1;
for (int k = j; k < j + step; k++, omega_nk = (omega_nk * omega_n % mod))
{
ll x = arr[k], y = omega_nk * arr[k + step] % mod;
arr[k] = (x + y) % mod, arr[k + step] = (x - y + mod) % mod;
}
}
}
if (dft != 1)
{
ll inv = quick_pow(limit, mod - 2);
for (int i = 0; i < limit; i++)
arr[i] = (arr[i] * inv % mod);
}
}
void poly_inverse(int deg, ll *a, ll *b)
{
if (deg == 1)
return (void)(b[0] = quick_pow(a[0], mod - 2));
poly_inverse((deg + 1) >> 1, a, b);
ll limit = 2, mx_bit = 1;
while ((deg << 1) > limit)
limit <<= 1, mx_bit++;
for (int i = 1; i < limit; i++)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1)));
for (int i = 0; i < deg; i++)
tmp[i] = a[i];
for (int i = deg; i < limit; i++)
tmp[i] = 0;
ntt(tmp, limit, 1), ntt(b, limit, 1);
for (int i = 0; i < limit; i++)
b[i] = 1LL * ((2LL - tmp[i] * b[i] % mod + mod) % mod) * b[i] % mod;
ntt(b, limit, -1);
for (int i = deg; i < limit; i++)
b[i] = 0;
}
int main()
{
Gi = quick_pow(G, mod - 2);
scanf("%d", &n);
level[1] = level[0] = 1, level_inv[0] = 1;
for (int i = 2; i <= n; i++)
level[i] = 1LL * level[i - 1] * i % mod;
level_inv[n] = quick_pow(level[n], mod - 2);
for (int i = n - 1; i >= 1; i--)
level_inv[i] = 1LL * level_inv[i + 1] * (i + 1) % mod;
gx[1] = gx[0] = 1;
for (int i = 2; i <= n; i++)
gx[i] = quick_pow(2, ((1LL * i * (i - 1)) >> 1) % (mod - 1));
for (int i = 0; i <= n; i++)
GX[i] = 1LL * gx[i] * level_inv[i] % mod;
for (int i = 1; i <= n; i++)
CX[i] = 1LL * gx[i] * level_inv[i - 1] % mod;
poly_inverse(n + 1, GX, invG);
ll limit = 2, mx_bit = 1;
while (((n + 1) << 1) > limit)
limit <<= 1, mx_bit++;
for (int i = 1; i < limit; i++)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1)));
ntt(invG, limit, 1), ntt(CX, limit, 1);
for (int i = 0; i < limit; i++)
invG[i] = (invG[i] * CX[i]) % mod;
ntt(invG, limit, -1);
ll ans = invG[n] * quick_pow(level_inv[n - 1], mod - 2) % mod;
while (ans < 0)
ans += mod;
ans %= mod;
printf("%lld", ans);
return 0;
}
// C.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int mod = 1004535809, G = 3, MAX_N = 1 << 18; int n, rev[MAX_N], Gi; ll tmp[MAX_N], level[MAX_N], level_inv[MAX_N], gx[MAX_N], CX[MAX_N], GX[MAX_N], invG[MAX_N]; ll quick_pow(ll bas, ll tim) { ll ans = 1; bas %= mod; while (tim) { if (tim & 1) ans = ans * bas % mod; bas = bas * bas % mod; tim >>= 1; } return ans; } inline void ntt(ll *arr, int limit, int dft) { for (int i = 0; i < limit; i++) if (i < rev[i]) swap(arr[i], arr[rev[i]]); for (int step = 1; step < limit; step <<= 1) { ll omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1)); for (int j = 0; j < limit; j += (step << 1)) { ll omega_nk = 1; for (int k = j; k < j + step; k++, omega_nk = (omega_nk * omega_n % mod)) { ll x = arr[k], y = omega_nk * arr[k + step] % mod; arr[k] = (x + y) % mod, arr[k + step] = (x - y + mod) % mod; } } } if (dft != 1) { ll inv = quick_pow(limit, mod - 2); for (int i = 0; i < limit; i++) arr[i] = (arr[i] * inv % mod); } } void poly_inverse(int deg, ll *a, ll *b) { if (deg == 1) return (void)(b[0] = quick_pow(a[0], mod - 2)); poly_inverse((deg + 1) >> 1, a, b); ll limit = 2, mx_bit = 1; while ((deg << 1) > limit) limit <<= 1, mx_bit++; for (int i = 1; i < limit; i++) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1))); for (int i = 0; i < deg; i++) tmp[i] = a[i]; for (int i = deg; i < limit; i++) tmp[i] = 0; ntt(tmp, limit, 1), ntt(b, limit, 1); for (int i = 0; i < limit; i++) b[i] = 1LL * ((2LL - tmp[i] * b[i] % mod + mod) % mod) * b[i] % mod; ntt(b, limit, -1); for (int i = deg; i < limit; i++) b[i] = 0; } int main() { Gi = quick_pow(G, mod - 2); scanf("%d", &n); level[1] = level[0] = 1, level_inv[0] = 1; for (int i = 2; i <= n; i++) level[i] = 1LL * level[i - 1] * i % mod; level_inv[n] = quick_pow(level[n], mod - 2); for (int i = n - 1; i >= 1; i--) level_inv[i] = 1LL * level_inv[i + 1] * (i + 1) % mod; gx[1] = gx[0] = 1; for (int i = 2; i <= n; i++) gx[i] = quick_pow(2, ((1LL * i * (i - 1)) >> 1) % (mod - 1)); for (int i = 0; i <= n; i++) GX[i] = 1LL * gx[i] * level_inv[i] % mod; for (int i = 1; i <= n; i++) CX[i] = 1LL * gx[i] * level_inv[i - 1] % mod; poly_inverse(n + 1, GX, invG); ll limit = 2, mx_bit = 1; while (((n + 1) << 1) > limit) limit <<= 1, mx_bit++; for (int i = 1; i < limit; i++) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1))); ntt(invG, limit, 1), ntt(CX, limit, 1); for (int i = 0; i < limit; i++) invG[i] = (invG[i] * CX[i]) % mod; ntt(invG, limit, -1); ll ans = invG[n] * quick_pow(level_inv[n - 1], mod - 2) % mod; while (ans < 0) ans += mod; ans %= mod; printf("%lld", ans); return 0; }
// C.cpp
#include <bits/stdc++.h>

#define ll long long

using namespace std;

const int mod = 1004535809, G = 3, MAX_N = 1 << 18;

int n, rev[MAX_N], Gi;
ll tmp[MAX_N], level[MAX_N], level_inv[MAX_N], gx[MAX_N], CX[MAX_N], GX[MAX_N], invG[MAX_N];

ll quick_pow(ll bas, ll tim)
{
    ll ans = 1;
    bas %= mod;
    while (tim)
    {
        if (tim & 1)
            ans = ans * bas % mod;
        bas = bas * bas % mod;
        tim >>= 1;
    }
    return ans;
}

inline void ntt(ll *arr, int limit, int dft)
{
    for (int i = 0; i < limit; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    for (int step = 1; step < limit; step <<= 1)
    {
        ll omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
        for (int j = 0; j < limit; j += (step << 1))
        {
            ll omega_nk = 1;
            for (int k = j; k < j + step; k++, omega_nk = (omega_nk * omega_n % mod))
            {
                ll x = arr[k], y = omega_nk * arr[k + step] % mod;
                arr[k] = (x + y) % mod, arr[k + step] = (x - y + mod) % mod;
            }
        }
    }
    if (dft != 1)
    {
        ll inv = quick_pow(limit, mod - 2);
        for (int i = 0; i < limit; i++)
            arr[i] = (arr[i] * inv % mod);
    }
}

void poly_inverse(int deg, ll *a, ll *b)
{
    if (deg == 1)
        return (void)(b[0] = quick_pow(a[0], mod - 2));
    poly_inverse((deg + 1) >> 1, a, b);

    ll limit = 2, mx_bit = 1;
    while ((deg << 1) > limit)
        limit <<= 1, mx_bit++;
    for (int i = 1; i < limit; i++)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1)));

    for (int i = 0; i < deg; i++)
        tmp[i] = a[i];
    for (int i = deg; i < limit; i++)
        tmp[i] = 0;
    ntt(tmp, limit, 1), ntt(b, limit, 1);

    for (int i = 0; i < limit; i++)
        b[i] = 1LL * ((2LL - tmp[i] * b[i] % mod + mod) % mod) * b[i] % mod;

    ntt(b, limit, -1);
    for (int i = deg; i < limit; i++)
        b[i] = 0;
}

int main()
{
    Gi = quick_pow(G, mod - 2);
    scanf("%d", &n);

    level[1] = level[0] = 1, level_inv[0] = 1;
    for (int i = 2; i <= n; i++)
        level[i] = 1LL * level[i - 1] * i % mod;
    level_inv[n] = quick_pow(level[n], mod - 2);
    for (int i = n - 1; i >= 1; i--)
        level_inv[i] = 1LL * level_inv[i + 1] * (i + 1) % mod;

    gx[1] = gx[0] = 1;
    for (int i = 2; i <= n; i++)
        gx[i] = quick_pow(2, ((1LL * i * (i - 1)) >> 1) % (mod - 1));

    for (int i = 0; i <= n; i++)
        GX[i] = 1LL * gx[i] * level_inv[i] % mod;
    for (int i = 1; i <= n; i++)
        CX[i] = 1LL * gx[i] * level_inv[i - 1] % mod;

    poly_inverse(n + 1, GX, invG);

    ll limit = 2, mx_bit = 1;
    while (((n + 1) << 1) > limit)
        limit <<= 1, mx_bit++;
    for (int i = 1; i < limit; i++)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1)));

    ntt(invG, limit, 1), ntt(CX, limit, 1);
    for (int i = 0; i < limit; i++)
        invG[i] = (invG[i] * CX[i]) % mod;
    ntt(invG, limit, -1);
    ll ans = invG[n] * quick_pow(level_inv[n - 1], mod - 2) % mod;
    while (ans < 0)
        ans += mod;
    ans %= mod;
    printf("%lld", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *