Loading [MathJax]/extensions/tex2jax.js

「Fortuna OJ」Apr 19 省选 A 组 – 解题报告

A – Directory Traversal

没什么好说的,比较水的一个换根 DP。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// traversal.cpp
#include <bits/stdc++.h>
using namespace std;
constexpr int MAX_N = 1e5 + 200;

using ll = long long;

int n;           // number of entries (files + directories), 1-indexed
int siz[MAX_N];  // siz[u]: number of files in u's subtree
int head[MAX_N]; // head[u]: first edge out of u, -1 when none (set in main)
int current;     // next free slot in edges[]
int wi[MAX_N];   // wi[u]: length of u's name
int tot;         // total number of files
ll dp[MAX_N];    // dp[u]: path-length cost to the files inside u's subtree
ll g[MAX_N];     // g[u]: path-length cost to the files outside u's subtree
char name[20];   // scratch buffer for reading one name
bool tag[MAX_N]; // tag[u]: u is a file (it was read with zero children)

// Adjacency-list edge: target node and index of the next edge of the source.
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];
void addpath(int src, int dst)
{
edges[current].to = dst, edges[current].nxt = head[src];
head[src] = current++;
}
void dfs(int u, int fa)
{
for (int i = head[u]; i != -1; i = edges[i].nxt)
{
dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
}
}
void collect(int u, int fa)
{
if (fa != 0)
g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
for (int i = head[u]; i != -1; i = edges[i].nxt)
collect(edges[i].to, u);
}
// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}
// Reads the tree ("name m child_1 ... child_m" per entry, m == 0 for files),
// runs both DP passes and prints the minimum total path length over all
// directories. Files are not valid starting points, so they are skipped.
int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1, m, val; i <= n; i++)
    {
        // %19s keeps the read inside name[20].
        scanf("%19s%d", name, &m), wi[i] = strlen(name);
        if (m == 0)
            siz[i] = 1, tag[i] = true, tot++;
        else
            while (m--)
                scanf("%d", &val), addpath(i, val);
    }
    dfs(1, 0), collect(1, 0);
    ll ans = LLONG_MAX;
    for (int i = 1; i <= n; i++)
        if (!tag[i]) // a file's dp+g can undercut its parent, but it is not a directory
            ans = min(ans, dp[i] + g[i]);
    printf("%lld\n", ans);
    return 0;
}
// traversal.cpp
// Rerooting DP: for every directory, the total length of the relative paths to
// all files; prints the minimum over directories.
#include <bits/stdc++.h>

using namespace std;

constexpr int MAX_N = 1e5 + 200;

using ll = long long;

int n;           // number of entries (files + directories), 1-indexed
int siz[MAX_N];  // siz[u]: number of files in u's subtree
int head[MAX_N]; // head[u]: first edge out of u, -1 when none (set in main)
int current;     // next free slot in edges[]
int wi[MAX_N];   // wi[u]: length of u's name
int tot;         // total number of files
ll dp[MAX_N];    // dp[u]: path-length cost to the files inside u's subtree
ll g[MAX_N];     // g[u]: path-length cost to the files outside u's subtree
char name[20];   // scratch buffer for reading one name
bool tag[MAX_N]; // tag[u]: u is a file (it was read with zero children)

// Adjacency-list edge: target node and index of the next edge of the source.
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

// Prepends the directed edge src -> dst to src's adjacency list.
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

// Post-order pass: accumulates siz[u] and dp[u] from the children of u.
void dfs(int u, int fa)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
        // Files below a child pay for its name plus a separator (none after a file).
        dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
    }
}

// Rerooting pass: g[u] from the parent's totals; each outside file costs "../".
void collect(int u, int fa)
{
    if (fa != 0)
        g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        collect(edges[i].to, u);
}

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// Reads the tree and prints the minimum total path length over directories.
int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1, m, val; i <= n; i++)
    {
        scanf("%19s%d", name, &m), wi[i] = strlen(name);
        if (m == 0)
            siz[i] = 1, tag[i] = true, tot++;
        else
            while (m--)
                scanf("%d", &val), addpath(i, val);
    }
    dfs(1, 0), collect(1, 0);
    ll ans = LLONG_MAX;
    for (int i = 1; i <= n; i++)
        if (!tag[i]) // only directories are valid starting points
            ans = min(ans, dp[i] + g[i]);
    printf("%lld\n", ans);
    return 0;
}
// traversal.cpp
#include <bits/stdc++.h>

using namespace std;

constexpr int MAX_N = 1e5 + 200;

using ll = long long;

int n;           // number of entries (files + directories), 1-indexed
int siz[MAX_N];  // siz[u]: number of files in u's subtree
int head[MAX_N]; // head[u]: first edge out of u, -1 when none (set in main)
int current;     // next free slot in edges[]
int wi[MAX_N];   // wi[u]: length of u's name
int tot;         // total number of files
ll dp[MAX_N];    // dp[u]: path-length cost to the files inside u's subtree
ll g[MAX_N];     // g[u]: path-length cost to the files outside u's subtree
char name[20];   // scratch buffer for reading one name
bool tag[MAX_N]; // tag[u]: u is a file (it was read with zero children)

// Adjacency-list edge: target node and index of the next edge of the source.
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

void dfs(int u, int fa)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
        dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
    }
}

void collect(int u, int fa)
{
    if (fa != 0)
        g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        collect(edges[i].to, u);
}

// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}

// Reads the tree ("name m child_1 ... child_m" per entry, m == 0 for files),
// runs both DP passes and prints the minimum total path length over all
// directories. Files are not valid starting points, so they are skipped.
int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1, m, val; i <= n; i++)
    {
        // %19s keeps the read inside name[20].
        scanf("%19s%d", name, &m), wi[i] = strlen(name);
        if (m == 0)
            siz[i] = 1, tag[i] = true, tot++;
        else
            while (m--)
                scanf("%d", &val), addpath(i, val);
    }
    dfs(1, 0), collect(1, 0);
    ll ans = LLONG_MAX;
    for (int i = 1; i <= n; i++)
        if (!tag[i]) // a file's dp+g can undercut its parent, but it is not a directory
            ans = min(ans, dp[i] + g[i]);
    printf("%lld\n", ans);
    return 0;
}

B – 迫害 DJ

这道题其实在考场上能看得出来是找 Fibonacci 循环节,但是我不太会,所以打了个暴力走人。事后发现大家都不会。

首先,我们知道这道题本质是在求多个模域下的 \(g_x\)。\(g_x\) 可以被改写成 \(g_x = b \times Fib_{2x} - a \times Fib_{2(x - 1)}\)(其中 \(g_0 = a, g_1 = b\),\(Fib_0 = 0, Fib_1 = 1\))。所以问题被简化成求 Fibonacci 数列在若干个模域下的值。

求循环节比较人类智慧,设 \(mod = \prod p_i^{c_i}\),则:

\[ f(p) = \begin{cases} 3, & p = 2 \\ 8, & p = 3 \\ 20, & p = 5 \\ p - 1, & p \bmod 5 \in \{1, 4\} \\ 2p + 2, & \text{otherwise} \end{cases} \]
\[ \text{loop} = \operatorname{lcm}_i \left\{ f(p_i) \times p_i^{c_i - 1} \right\} \]

证明就不会了,有兴趣可以去搜一搜 Pisano 周期(Fibonacci 数列模 \(m\) 的周期)的相关结论。注意这里只需要一个周期的倍数,不需要最小周期。

知道这个循环节之后,考虑设计子任务 \(solve(n, k, mod)\),然后递归求即可。快速求 \(g_n\) 直接矩阵乘法即可。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// hakugai.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity of loops[] (distinct prime factors of one modulus).
constexpr int MAX_N = 50;

int T;        // number of test cases
int a, b;     // g_0 and g_1
int n, level; // starting index and number of nested applications
int mod;      // modulus of the final answer

using ll = long long;

ll loops[MAX_N]; // scratch for getLoopLen: per prime-power period factors
ll tot;          // number of used entries in loops[]
ll cmod;         // modulus used by matrix arithmetic (set by getGn)
// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}
// 2x2 matrix with entries reduced modulo the global `cmod` (set by getGn).
// At nested levels cmod is a period length that can exceed 2^32, so a plain
// 64-bit product of two entries can overflow; products are formed in 128 bits
// (GCC/Clang __int128, consistent with the file's use of GCC's __gcd).
struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    void clear() { memset(mat, 0, sizeof(mat)); };
    // Sets the diagonal to 1; after clear() this is the identity.
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    // Product modulo cmod; zero entries are skipped.
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                        {
                            const __int128 prod = static_cast<__int128>(mat[i][k]) * rhs.mat[k][j];
                            ret[i][j] = static_cast<ll>((ret[i][j] + prod) % cmod);
                        }
        return ret;
    }
    // this^rhs for non-negative rhs by binary exponentiation (rhs == 0 gives the identity).
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;
// Returns g_x mod cm, where g_0 = a and g_1 = b (the globals) and, for x >= 2,
// g_x = b * F_{2x} - a * F_{2x-2} with F_0 = 0, F_1 = 1. Also sets the global
// cmod used by the matrix operators.
ll getGn(ll x, ll cm)
{
    cmod = cm;
    // The seeds are returned directly for x <= 1 and used as multipliers
    // below, so reduce them first (the original returned a/b unreduced).
    const ll a_mod = a % cmod, b_mod = b % cmod;
    if (x == 0)
        return a_mod;
    if (x == 1)
        return b_mod;
    // With init = [0, 1] and trans = [[0, 1], [1, 1]] (as set up in main),
    // (init * trans^k)[0][1] = F_{k+1}.
    const ll fib_2x = (init * (trans ^ (2 * x - 1)))[0][1];
    const ll fib_2x_2 = (init * (trans ^ (2 * (x - 1) - 1)))[0][1];
    // cm can exceed 2^32 at nested levels: multiply in 128 bits.
    const ll term_b = static_cast<ll>(static_cast<__int128>(fib_2x) * b_mod % cmod);
    const ll term_a = static_cast<ll>(static_cast<__int128>(fib_2x_2) * a_mod % cmod);
    return (term_b + cmod - term_a) % cmod;
}
// Period bound used for a prime pbase: 3, 8 and 20 for 2, 3 and 5; otherwise
// pbase - 1 when pbase mod 5 is 1 or 4, and 2 * pbase + 2 in every other case.
ll getFn(ll pbase)
{
    switch (pbase)
    {
    case 2:
        return 3;
    case 3:
        return 8;
    case 5:
        return 20;
    default:
        break;
    }
    const ll residue = pbase % 5;
    return (residue == 1 || residue == 4) ? pbase - 1 : 2 * pbase + 2;
}
// Returns a period (not necessarily the minimal one) of n -> g_n mod x: the
// lcm over the prime powers p^c exactly dividing x of getFn(p) * p^(c-1).
// Uses loops[1..tot] as scratch. x == 1 has no prime factors and yields 1.
ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (ll i = 2; i * i <= acc; i++)
        if (acc % i == 0)
        {
            ll pw = 1; // i^(c-1), where i^c exactly divides x
            acc /= i;
            while (acc % i == 0)
                acc /= i, pw *= i;
            loops[++tot] = getFn(i) * pw;
        }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    // Start from 1: with tot == 0 (x == 1) the original read a stale or zero
    // loops[1], which later led to a modulo by zero.
    ll loop = 1;
    for (int i = 1; i <= tot; i++)
        loop = loop / __gcd(loop, loops[i]) * loops[i];
    return loop;
}
// Value of g applied k_ times to n_ (g(g(...g(n_)...))), modulo cm. The inner
// value is only needed modulo a period of the outer evaluation.
ll solve(ll n_, ll k_, ll cm)
{
    if (k_ != 0)
    {
        const ll inner_period = getLoopLen(cm);
        const ll inner = solve(n_, k_ - 1, inner_period);
        return getGn(inner, cm);
    }
    return n_ % cm;
}
int main()
{
// fileIO("hakugai");
scanf("%d", &T);
trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
init.clear(), init[0][0] = 0, init[0][1] = 1;
while (T--)
{
scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
// get loop;
printf("%lld\n", solve(n, level, mod));
}
return 0;
}
// hakugai.cpp
// Nested evaluation g(g(...g(n)...)) mod p, where g_0 = a, g_1 = b and
// g_x = b * F_{2x} - a * F_{2x-2}; each inner level is reduced modulo a period
// of the level around it (see getLoopLen).
#include <bits/stdc++.h>

using namespace std;

// Capacity of loops[] (distinct prime factors of one modulus).
constexpr int MAX_N = 50;

int T, a, b, n, level, mod;

using ll = long long;

ll loops[MAX_N], tot, cmod; // getLoopLen scratch; modulus for matrix ops

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// 2x2 matrix modulo cmod; products use __int128 because cmod may exceed 2^32.
struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    void clear() { memset(mat, 0, sizeof(mat)); };
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                        {
                            const __int128 prod = static_cast<__int128>(mat[i][k]) * rhs.mat[k][j];
                            ret[i][j] = static_cast<ll>((ret[i][j] + prod) % cmod);
                        }
        return ret;
    }
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;

// g_x mod cm; the seeds are reduced and the products done in 128 bits.
ll getGn(ll x, ll cm)
{
    cmod = cm;
    const ll a_mod = a % cmod, b_mod = b % cmod;
    if (x == 0)
        return a_mod;
    if (x == 1)
        return b_mod;
    const ll fib_2x = (init * (trans ^ (2 * x - 1)))[0][1];
    const ll fib_2x_2 = (init * (trans ^ (2 * (x - 1) - 1)))[0][1];
    const ll term_b = static_cast<ll>(static_cast<__int128>(fib_2x) * b_mod % cmod);
    const ll term_a = static_cast<ll>(static_cast<__int128>(fib_2x_2) * a_mod % cmod);
    return (term_b + cmod - term_a) % cmod;
}

// Period bound for prime pbase.
ll getFn(ll pbase)
{
    if (pbase == 2)
        return 3;
    if (pbase == 3)
        return 8;
    if (pbase == 5)
        return 20;
    if (pbase % 5 == 1 || pbase % 5 == 4)
        return pbase - 1;
    return 2 * pbase + 2;
}

// lcm of getFn(p) * p^(c-1) over p^c || x; 1 when x == 1.
ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (ll i = 2; i * i <= acc; i++)
        if (acc % i == 0)
        {
            ll pw = 1;
            acc /= i;
            while (acc % i == 0)
                acc /= i, pw *= i;
            loops[++tot] = getFn(i) * pw;
        }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    ll loop = 1;
    for (int i = 1; i <= tot; i++)
        loop = loop / __gcd(loop, loops[i]) * loops[i];
    return loop;
}

// g applied k_ times to n_, modulo cm.
ll solve(ll n_, ll k_, ll cm)
{
    if (k_ == 0)
        return n_ % cm;
    ll nxt_loop = getLoopLen(cm), now = solve(n_, k_ - 1, nxt_loop);
    return getGn(now, cm);
}

int main()
{
    // fileIO("hakugai");
    scanf("%d", &T);
    trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
    init.clear(), init[0][0] = 0, init[0][1] = 1;
    while (T--)
    {
        scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
        printf("%lld\n", solve(n, level, mod));
    }
    return 0;
}
// hakugai.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity of loops[] (distinct prime factors of one modulus).
constexpr int MAX_N = 50;

int T;        // number of test cases
int a, b;     // g_0 and g_1
int n, level; // starting index and number of nested applications
int mod;      // modulus of the final answer

using ll = long long;

ll loops[MAX_N]; // scratch for getLoopLen: per prime-power period factors
ll tot;          // number of used entries in loops[]
ll cmod;         // modulus used by matrix arithmetic (set by getGn)

// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}

// 2x2 matrix with entries reduced modulo the global `cmod` (set by getGn).
// At nested levels cmod is a period length that can exceed 2^32, so a plain
// 64-bit product of two entries can overflow; products are formed in 128 bits
// (GCC/Clang __int128, consistent with the file's use of GCC's __gcd).
struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    void clear() { memset(mat, 0, sizeof(mat)); };
    // Sets the diagonal to 1; after clear() this is the identity.
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    // Product modulo cmod; zero entries are skipped.
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                        {
                            const __int128 prod = static_cast<__int128>(mat[i][k]) * rhs.mat[k][j];
                            ret[i][j] = static_cast<ll>((ret[i][j] + prod) % cmod);
                        }
        return ret;
    }
    // this^rhs for non-negative rhs by binary exponentiation (rhs == 0 gives the identity).
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;

// Returns g_x mod cm, where g_0 = a and g_1 = b (the globals) and, for x >= 2,
// g_x = b * F_{2x} - a * F_{2x-2} with F_0 = 0, F_1 = 1. Also sets the global
// cmod used by the matrix operators.
ll getGn(ll x, ll cm)
{
    cmod = cm;
    // The seeds are returned directly for x <= 1 and used as multipliers
    // below, so reduce them first (the original returned a/b unreduced).
    const ll a_mod = a % cmod, b_mod = b % cmod;
    if (x == 0)
        return a_mod;
    if (x == 1)
        return b_mod;
    // With init = [0, 1] and trans = [[0, 1], [1, 1]] (as set up in main),
    // (init * trans^k)[0][1] = F_{k+1}.
    const ll fib_2x = (init * (trans ^ (2 * x - 1)))[0][1];
    const ll fib_2x_2 = (init * (trans ^ (2 * (x - 1) - 1)))[0][1];
    // cm can exceed 2^32 at nested levels: multiply in 128 bits.
    const ll term_b = static_cast<ll>(static_cast<__int128>(fib_2x) * b_mod % cmod);
    const ll term_a = static_cast<ll>(static_cast<__int128>(fib_2x_2) * a_mod % cmod);
    return (term_b + cmod - term_a) % cmod;
}

// Period bound used for a prime pbase: 3, 8 and 20 for 2, 3 and 5; otherwise
// pbase - 1 when pbase mod 5 is 1 or 4, and 2 * pbase + 2 in every other case.
ll getFn(ll pbase)
{
    switch (pbase)
    {
    case 2:
        return 3;
    case 3:
        return 8;
    case 5:
        return 20;
    default:
        break;
    }
    const ll residue = pbase % 5;
    return (residue == 1 || residue == 4) ? pbase - 1 : 2 * pbase + 2;
}

// Returns a period (not necessarily the minimal one) of n -> g_n mod x: the
// lcm over the prime powers p^c exactly dividing x of getFn(p) * p^(c-1).
// Uses loops[1..tot] as scratch. x == 1 has no prime factors and yields 1.
ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (ll i = 2; i * i <= acc; i++)
        if (acc % i == 0)
        {
            ll pw = 1; // i^(c-1), where i^c exactly divides x
            acc /= i;
            while (acc % i == 0)
                acc /= i, pw *= i;
            loops[++tot] = getFn(i) * pw;
        }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    // Start from 1: with tot == 0 (x == 1) the original read a stale or zero
    // loops[1], which later led to a modulo by zero.
    ll loop = 1;
    for (int i = 1; i <= tot; i++)
        loop = loop / __gcd(loop, loops[i]) * loops[i];
    return loop;
}

// Value of g applied k_ times to n_ (g(g(...g(n_)...))), modulo cm. The inner
// value is only needed modulo a period of the outer evaluation.
ll solve(ll n_, ll k_, ll cm)
{
    if (k_ != 0)
    {
        const ll inner_period = getLoopLen(cm);
        const ll inner = solve(n_, k_ - 1, inner_period);
        return getGn(inner, cm);
    }
    return n_ % cm;
}

int main()
{
    // fileIO("hakugai");
    scanf("%d", &T);
    trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
    init.clear(), init[0][0] = 0, init[0][1] = 1;
    while (T--)
    {
        scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
        // get loop;
        printf("%lld\n", solve(n, level, mod));
    }
    return 0;
}

C – 夕张的改造

这题就比较神仙了。

考虑矩阵树定理,直接做没法限制 \(k\) 这个东西。人类智慧告诉我们,如果我们把原树上的边的边权赋为 \(1\),不在原树上的边的边权赋为 \(x\),那么最后我们矩阵树求出来的东西肯定就是一个关于 \(x\) 的多项式 \(P(x)\):其中 \(x^j\) 的系数恰好是使用了 \(j\) 条非原树边的生成树个数,所以答案就是前 \(k + 1\) 项(\(x^0\) 到 \(x^k\))的系数之和。

由于 \(P(x)\) 是 \((n - 1)\) 阶行列式,次数不超过 \(n - 1\)。我们让 \(x = 1, 2, \dots, n\) 求出 \(n\) 个点值,然后用高斯消元解范德蒙德方程组插值得到多项式的系数,再把前 \(k + 1\) 项系数加起来即可。这个题我个人很喜欢。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// kaisou.cpp
#include <bits/stdc++.h>
using namespace std;
// MAX_N bounds matrix dimensions; all arithmetic is modulo `mod`.
constexpr int MAX_N = 55, mod = 998244353;

int n;                 // number of vertices (0-indexed)
int limit;             // maximum number of non-tree edges allowed (the k)
int mat[MAX_N][MAX_N]; // working matrix shared by gauss() and interpolation
int poly_val[MAX_N];   // poly_val[x]: spanning-tree polynomial evaluated at x
int poly_coeff[MAX_N]; // interpolated coefficients of that polynomial
bool mp[MAX_N][MAX_N]; // mp[i][j]: i and j are adjacent in the input tree
// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}
// bas^tim modulo `mod` by binary exponentiation (fpow(v, mod - 2) is the
// modular inverse of v, since mod is prime).
int fpow(int bas, int tim)
{
    long long result = 1, base = bas;
    for (; tim != 0; tim >>= 1)
    {
        if (tim & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}
// Determinant, modulo `mod`, of the leading (n-1)x(n-1) block of mat, by
// Gaussian elimination; mat is overwritten. main passes a Laplacian, so by
// the matrix-tree theorem this is the weighted spanning-tree count.
int gauss()
{
    const int dim = n - 1;
    int det = 1;
    for (int col = 0; col < dim; col++)
    {
        // First row at or below the diagonal with a non-zero entry in this column.
        int pivot = col;
        while (pivot < dim && mat[pivot][col] == 0)
            pivot++;
        if (pivot == dim)
            pivot = col; // none found: keep the zero diagonal, the product becomes 0
        if (pivot != col)
        {
            det = mod - det; // a row swap flips the sign
            for (int c = col; c < dim; c++)
                swap(mat[col][c], mat[pivot][c]);
        }
        const int inv = fpow(mat[col][col], mod - 2);
        for (int row = col + 1; row < dim; row++)
        {
            const int factor = 1LL * mat[row][col] * inv % mod;
            for (int c = col; c < dim; c++)
                mat[row][c] = (0LL + mat[row][c] + mod - 1LL * factor * mat[col][c] % mod) % mod;
        }
    }
    for (int d = 0; d < dim; d++)
        det = 1LL * det * mat[d][d] % mod;
    return det;
}
// Counts spanning trees of the complete graph on n vertices that use at most
// `limit` edges outside the input tree. P(x) = det(reduced Laplacian with
// weight 1 on tree edges and x on all other edges) is evaluated at
// x = 1..n, interpolated, and its coefficients of x^0..x^limit are summed.
int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    // Vertices are 0..n-1; for each i >= 1 one neighbour fa (presumably its parent) is read.
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    // Point values P(1..n).
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                const int w = mp[i][j] ? 1 : x; // tree edge -> 1, other edge -> x
                mat[i][j] = mod - w, mat[i][i] += w;
            }
        poly_val[x] = gauss();
    }
    // Vandermonde system: row i is [1, (i+1), ..., (i+1)^(n-1) | P(i+1)].
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    // Gauss-Jordan elimination. Column i of the other rows is left stale
    // (only columns > i are updated) because it is never read again.
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    // deg P <= n - 1, so higher coefficients are zero; clamping also keeps the
    // index inside poly_coeff[MAX_N] when limit is large.
    const int top = min(limit, n - 1);
    int ans = 0;
    for (int i = 0; i <= top; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}
// kaisou.cpp
// Spanning trees of the complete graph using at most `limit` non-tree edges,
// via the matrix-tree theorem with weight x on non-tree edges plus interpolation.
#include <bits/stdc++.h>

using namespace std;

constexpr int MAX_N = 55, mod = 998244353;

int n, limit, mat[MAX_N][MAX_N], poly_val[MAX_N], poly_coeff[MAX_N];
bool mp[MAX_N][MAX_N]; // adjacency of the input tree

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// bas^tim modulo mod.
int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

// Determinant of the leading (n-1)x(n-1) block of mat, modulo mod.
int gauss()
{
    int res = 1;
    for (int i = 0; i < n - 1; i++)
    {
        int key = i;
        for (int j = i; j < n - 1; j++)
            if (mat[j][i] != 0)
            {
                key = j;
                break;
            }
        if (key != i)
        {
            res = mod - res;
            for (int j = i; j < n - 1; j++)
                swap(mat[i][j], mat[key][j]);
        }
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = i + 1; j < n - 1; j++)
        {
            int rate = 1LL * mat[j][i] * inv % mod;
            for (int k = i; k < n - 1; k++)
                mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
        }
    }
    for (int i = 0; i < n - 1; i++)
        res = 1LL * res * mat[i][i] % mod;
    return res;
}

int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                const int w = mp[i][j] ? 1 : x; // tree edge -> 1, other edge -> x
                mat[i][j] = mod - w, mat[i][i] += w;
            }
        poly_val[x] = gauss();
    }
    // Interpolate P from its values at 1..n (Vandermonde, Gauss-Jordan).
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    // deg P <= n - 1: clamp so large limits stay inside poly_coeff[].
    const int top = min(limit, n - 1);
    int ans = 0;
    for (int i = 0; i <= top; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}
// kaisou.cpp
#include <bits/stdc++.h>

using namespace std;

// MAX_N bounds matrix dimensions; all arithmetic is modulo `mod`.
constexpr int MAX_N = 55, mod = 998244353;

int n;                 // number of vertices (0-indexed)
int limit;             // maximum number of non-tree edges allowed (the k)
int mat[MAX_N][MAX_N]; // working matrix shared by gauss() and interpolation
int poly_val[MAX_N];   // poly_val[x]: spanning-tree polynomial evaluated at x
int poly_coeff[MAX_N]; // interpolated coefficients of that polynomial
bool mp[MAX_N][MAX_N]; // mp[i][j]: i and j are adjacent in the input tree

// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}

// bas^tim modulo `mod` by binary exponentiation (fpow(v, mod - 2) is the
// modular inverse of v, since mod is prime).
int fpow(int bas, int tim)
{
    long long result = 1, base = bas;
    for (; tim != 0; tim >>= 1)
    {
        if (tim & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}

// Determinant, modulo `mod`, of the leading (n-1)x(n-1) block of mat, by
// Gaussian elimination; mat is overwritten. main passes a Laplacian, so by
// the matrix-tree theorem this is the weighted spanning-tree count.
int gauss()
{
    const int dim = n - 1;
    int det = 1;
    for (int col = 0; col < dim; col++)
    {
        // First row at or below the diagonal with a non-zero entry in this column.
        int pivot = col;
        while (pivot < dim && mat[pivot][col] == 0)
            pivot++;
        if (pivot == dim)
            pivot = col; // none found: keep the zero diagonal, the product becomes 0
        if (pivot != col)
        {
            det = mod - det; // a row swap flips the sign
            for (int c = col; c < dim; c++)
                swap(mat[col][c], mat[pivot][c]);
        }
        const int inv = fpow(mat[col][col], mod - 2);
        for (int row = col + 1; row < dim; row++)
        {
            const int factor = 1LL * mat[row][col] * inv % mod;
            for (int c = col; c < dim; c++)
                mat[row][c] = (0LL + mat[row][c] + mod - 1LL * factor * mat[col][c] % mod) % mod;
        }
    }
    for (int d = 0; d < dim; d++)
        det = 1LL * det * mat[d][d] % mod;
    return det;
}

// Counts spanning trees of the complete graph on n vertices that use at most
// `limit` edges outside the input tree. P(x) = det(reduced Laplacian with
// weight 1 on tree edges and x on all other edges) is evaluated at
// x = 1..n, interpolated, and its coefficients of x^0..x^limit are summed.
int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    // Vertices are 0..n-1; for each i >= 1 one neighbour fa (presumably its parent) is read.
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    // Point values P(1..n).
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                const int w = mp[i][j] ? 1 : x; // tree edge -> 1, other edge -> x
                mat[i][j] = mod - w, mat[i][i] += w;
            }
        poly_val[x] = gauss();
    }
    // Vandermonde system: row i is [1, (i+1), ..., (i+1)^(n-1) | P(i+1)].
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    // Gauss-Jordan elimination. Column i of the other rows is left stale
    // (only columns > i are updated) because it is never read again.
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    // deg P <= n - 1, so higher coefficients are zero; clamping also keeps the
    // index inside poly_coeff[MAX_N] when limit is large.
    const int top = min(limit, n - 1);
    int ans = 0;
    for (int i = 0; i <= top; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}

D – 亚特兰大

这题比赛的时候用了一个比较猥琐的方法,过掉了数据不强的 70 分。正解还是很有意思的。

我们考虑把一条边的每个约数都拆成一条独立的边,然后我们枚举每一个出现过的约数 \(d\),记只使用边权为 \(d\) 的倍数的边时连通的点对个数为 \(f(d)\),那么由莫比乌斯反演,路径上边权 \(\gcd\) 为 \(1\) 的点对个数就是 \(\sum_{d \ge 1} \mu(d) f(d)\)(只需枚举无平方因子的 \(d\),其余 \(\mu(d) = 0\))。这一个步骤非常的 nb。

有了这个之后我们就可以来算了。考虑到 \(q\) 很小,我们离线下来之后也分解出来,不过要标上时间。

我们先处理那些没动过的边,用并查集做。动完这些之后,我们就需要来处理询问的边。我们这个时候需要可撤销的并查集,用启发式合并来实现。最后累加到答案里面去。

我当天下午想了半个小时具体实现之后还是写不动,最后还是去看题解的代码了(wtcl)。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// atoranta.cpp
// #pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
constexpr int MAX_N = 1e6 + 200;

using ll = long long;

// NOTE: under C++17, `size` and `visit` share names with std::size and
// std::visit; with `using namespace std;` an unqualified use can be
// ambiguous, so refer to these globals as ::size / ::visit.
int n, q;               // vertex count, number of updates
int mu[MAX_N];          // Mobius function
int primes[MAX_N], tot; // primes found by the sieve (1-indexed) and their count
int mem[MAX_N];         // DSU parent, 0 = root
int size[MAX_N];        // DSU component size (meaningful at roots)
int last_time[MAX_N];   // last update index touching each edge, 0 = never
int visit[MAX_N];       // per-divisor time stamps used by solve()
ll ans[MAX_N], sum;     // answer per time 0..q; current connected-pair count
bool vis[MAX_N];        // sieve composite marks
stack<int> stk;         // DSU merge history for undo()

// An edge occurrence: edge id plus the time it belongs to (-1 = every time).
// main also uses time_frame to carry a weight while building qframes.
struct node
{
    int id, time_frame;
};
vector<node> frames[MAX_N], qframes[MAX_N]; // per divisor / per time
vector<int> facts[MAX_N];                   // facts[w]: distinct prime factors of w

// org: edge endpoints x, y and weight z; qseg: update (edge id x, new weight y).
struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];
// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}
// Linear sieve over [2, MAX_N): fills primes[1..tot], the composite marks in
// vis, and the Mobius function mu. Then stores in facts[j], for every
// j <= 1000000, the distinct prime factors of j in increasing order.
void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
        {
            primes[++tot] = i;
            mu[i] = -1;
        }
        for (int j = 1; j <= tot; j++)
        {
            const int p = primes[j];
            if (1LL * i * p >= MAX_N)
                break;
            const int composite = i * p;
            vis[composite] = true;
            if (i % p == 0)
            {
                mu[composite] = 0; // p^2 divides i * p
                break;
            }
            mu[composite] = -mu[i];
        }
    }
    // primes[1..tot] lists exactly the i in [2, MAX_N) with !vis[i], in increasing order.
    for (int k = 1; k <= tot; k++)
        for (int j = primes[k]; j <= 1000000; j += primes[k])
            facts[j].push_back(primes[k]);
}
// Root of x's DSU component. There is deliberately no path compression:
// undo() only resets mem[fy], so compressed paths could not be rolled back.
int find(int x)
{
    int root = x;
    while (mem[root] != 0)
        root = mem[root];
    return root;
}
// Union by size of the components of x and y. Adds the number of newly
// connected vertex pairs to `sum` and pushes the absorbed root onto `stk` so
// that undo() can revert it. Assumes x and y are in different components
// (the input edges form a tree).
void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    // ::size, not size: under C++17 with `using namespace std;` the unqualified
    // name is ambiguous with std::size and fails to compile.
    if (::size[fx] < ::size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * ::size[fx] * ::size[fy];
    ::size[fx] += ::size[fy], stk.push(fy);
}
// Reverts the most recent merge(): detaches the root on top of `stk` from its
// parent and restores both the component size and `sum`.
void undo()
{
    int fy = stk.top(), fx = mem[fy];
    // ::size avoids the C++17 ambiguity with std::size under `using namespace std;`.
    ::size[fx] -= ::size[fy], sum -= 1LL * ::size[fx] * ::size[fy], mem[fy] = 0;
    stk.pop();
}
void insert(int w, int id, int time_frame)
{
node u = node{id, time_frame};
int m = facts[w].size();
for (int stat = 0; stat < (1 << m); stat++)
{
int pans = 1;
for (int i = 0; i < m; i++)
if (stat & (1 << i))
pans *= facts[w][i];
frames[pans].push_back(u);
}
}
void solve(int x)
{
sum = 0;
for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
{
gx = i - 1;
int ctime = frames[x][i].time_frame;
while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
ll pans = 1LL * mu[x] * sum;
if (ctime != -1)
{
ans[ctime] += pans;
for (int tot = i; tot <= gx; tot++)
undo();
}
else
{
for (int k = gx + 1; k < siz; k++)
visit[frames[x][k].time_frame] = x;
for (int k = 0; k <= q; k++)
if (visit[k] != x)
ans[k] += pans;
}
}
while (!stk.empty())
undo();
}
// Reads the weighted tree and q weight updates, then for every time 0..q
// prints sum_d mu(d) * f_t(d) (see solve) computed offline.
int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    // Edges that are never updated hold at every time: insert once as static (-1).
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    // qframes[t]: (edge id, weight) for updated edges at time t; the node's
    // time_frame field temporarily carries the weight.
    for (int i = 1; i <= q; i++)
    {
        node u = node{qseg[i].x, qseg[i].y};
        // The new weight holds from time i until the same edge is updated again.
        for (int j = i; j <= q; j++)
        {
            if (j == i || qseg[i].x != qseg[j].x)
                qframes[j].push_back(u);
            else
                break;
        }
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        // First update of this edge: its original weight holds for times 0..i-1.
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    // ::size avoids the C++17 ambiguity with std::size under `using namespace std;`.
    for (int i = 1; i <= n; i++)
        ::size[i] = 1;
    for (int i = 1; i <= 1000000; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
// atoranta.cpp
// #pragma GCC optimize(2)
// Offline count of vertex pairs whose path has edge-weight gcd 1 at each of
// times 0..q, via sum_d mu(d) f(d) and a rollback DSU.
#include <bits/stdc++.h>

using namespace std;

constexpr int MAX_N = 1e6 + 200;

using ll = long long;

// `size` and `visit` clash with std::size / std::visit in C++17, so every use
// below is written as ::size / ::visit.
int n, q, mu[MAX_N], primes[MAX_N], tot, mem[MAX_N], size[MAX_N], last_time[MAX_N], visit[MAX_N];
ll ans[MAX_N], sum;
bool vis[MAX_N];
stack<int> stk;

// Edge occurrence: edge id and time (-1 = every time); time_frame temporarily
// carries a weight inside main.
struct node
{
    int id, time_frame;
};
vector<node> frames[MAX_N], qframes[MAX_N];
vector<int> facts[MAX_N];

struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// Linear sieve for mu, then distinct prime factors of every j <= 1e6.
void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
            primes[++tot] = i, mu[i] = -1;
        for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
        {
            vis[i * primes[j]] = true, mu[i * primes[j]] = -mu[i];
            if (i % primes[j] == 0)
            {
                mu[i * primes[j]] = 0;
                break;
            }
        }
    }
    for (int i = 2; i < MAX_N; i++)
        if (!vis[i])
            for (int j = i; j <= 1000000; j += i)
                facts[j].push_back(i);
}

// DSU root without path compression (required for undo()).
int find(int x)
{
    while (mem[x])
        x = mem[x];
    return x;
}

// Union by size; records the absorbed root for undo().
void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    if (::size[fx] < ::size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * ::size[fx] * ::size[fy];
    ::size[fx] += ::size[fy], stk.push(fy);
}

// Reverts the most recent merge().
void undo()
{
    int fy = stk.top(), fx = mem[fy];
    ::size[fx] -= ::size[fy], sum -= 1LL * ::size[fx] * ::size[fy], mem[fy] = 0;
    stk.pop();
}

// Files edge `id` at `time_frame` under every square-free divisor of w.
void insert(int w, int id, int time_frame)
{
    node u = node{id, time_frame};
    int m = facts[w].size();
    for (int stat = 0; stat < (1 << m); stat++)
    {
        int pans = 1;
        for (int i = 0; i < m; i++)
            if (stat & (1 << i))
                pans *= facts[w][i];
        frames[pans].push_back(u);
    }
}

// Adds mu(x) * f_t(x) to ans[t] for every time t.
void solve(int x)
{
    sum = 0;
    for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
    {
        gx = i - 1;
        int ctime = frames[x][i].time_frame;
        while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
            gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
        ll pans = 1LL * mu[x] * sum;
        if (ctime != -1)
        {
            ans[ctime] += pans;
            for (int k = i; k <= gx; k++)
                undo();
        }
        else
        {
            for (int k = gx + 1; k < siz; k++)
                ::visit[frames[x][k].time_frame] = x;
            for (int k = 0; k <= q; k++)
                if (::visit[k] != x)
                    ans[k] += pans;
        }
    }
    while (!stk.empty())
        undo();
}

int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    for (int i = 1; i <= q; i++)
    {
        node u = node{qseg[i].x, qseg[i].y};
        // marks;
        for (int j = i; j <= q; j++)
        {
            if (j == i || qseg[i].x != qseg[j].x)
                qframes[j].push_back(u);
            else
                break;
        }
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        // first;
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    for (int i = 1; i <= n; i++)
        ::size[i] = 1;
    for (int i = 1; i <= 1000000; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
// atoranta.cpp
// #pragma GCC optimize(2)
#include <bits/stdc++.h>

using namespace std;

constexpr int MAX_N = 1e6 + 200;

using ll = long long;

// NOTE: under C++17, `size` and `visit` share names with std::size and
// std::visit; with `using namespace std;` an unqualified use can be
// ambiguous, so refer to these globals as ::size / ::visit.
int n, q;               // vertex count, number of updates
int mu[MAX_N];          // Mobius function
int primes[MAX_N], tot; // primes found by the sieve (1-indexed) and their count
int mem[MAX_N];         // DSU parent, 0 = root
int size[MAX_N];        // DSU component size (meaningful at roots)
int last_time[MAX_N];   // last update index touching each edge, 0 = never
int visit[MAX_N];       // per-divisor time stamps used by solve()
ll ans[MAX_N], sum;     // answer per time 0..q; current connected-pair count
bool vis[MAX_N];        // sieve composite marks
stack<int> stk;         // DSU merge history for undo()

// An edge occurrence: edge id plus the time it belongs to (-1 = every time).
// main also uses time_frame to carry a weight while building qframes.
struct node
{
    int id, time_frame;
};
vector<node> frames[MAX_N], qframes[MAX_N]; // per divisor / per time
vector<int> facts[MAX_N];                   // facts[w]: distinct prime factors of w

// org: edge endpoints x, y and weight z; qseg: update (edge id x, new weight y).
struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];

// Redirects stdin/stdout to the files "<src>.in" and "<src>.out".
void fileIO(string src)
{
    const string input_file = src + ".in";
    const string output_file = src + ".out";
    freopen(input_file.c_str(), "r", stdin);
    freopen(output_file.c_str(), "w", stdout);
}

// Linear sieve over [2, MAX_N): fills primes[1..tot], the composite marks in
// vis, and the Mobius function mu. Then stores in facts[j], for every
// j <= 1000000, the distinct prime factors of j in increasing order.
void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
        {
            primes[++tot] = i;
            mu[i] = -1;
        }
        for (int j = 1; j <= tot; j++)
        {
            const int p = primes[j];
            if (1LL * i * p >= MAX_N)
                break;
            const int composite = i * p;
            vis[composite] = true;
            if (i % p == 0)
            {
                mu[composite] = 0; // p^2 divides i * p
                break;
            }
            mu[composite] = -mu[i];
        }
    }
    // primes[1..tot] lists exactly the i in [2, MAX_N) with !vis[i], in increasing order.
    for (int k = 1; k <= tot; k++)
        for (int j = primes[k]; j <= 1000000; j += primes[k])
            facts[j].push_back(primes[k]);
}

// Root of x's DSU component. There is deliberately no path compression:
// undo() only resets mem[fy], so compressed paths could not be rolled back.
int find(int x)
{
    int root = x;
    while (mem[root] != 0)
        root = mem[root];
    return root;
}

// Union by size of the components of x and y. Adds the number of newly
// connected vertex pairs to `sum` and pushes the absorbed root onto `stk` so
// that undo() can revert it. Assumes x and y are in different components
// (the input edges form a tree).
void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    // ::size, not size: under C++17 with `using namespace std;` the unqualified
    // name is ambiguous with std::size and fails to compile.
    if (::size[fx] < ::size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * ::size[fx] * ::size[fy];
    ::size[fx] += ::size[fy], stk.push(fy);
}

// Reverts the most recent merge(): detaches the root on top of `stk` from its
// parent and restores both the component size and `sum`.
void undo()
{
    int fy = stk.top(), fx = mem[fy];
    // ::size avoids the C++17 ambiguity with std::size under `using namespace std;`.
    ::size[fx] -= ::size[fy], sum -= 1LL * ::size[fx] * ::size[fy], mem[fy] = 0;
    stk.pop();
}

void insert(int w, int id, int time_frame)
{
    node u = node{id, time_frame};
    int m = facts[w].size();
    for (int stat = 0; stat < (1 << m); stat++)
    {
        int pans = 1;
        for (int i = 0; i < m; i++)
            if (stat & (1 << i))
                pans *= facts[w][i];
        frames[pans].push_back(u);
    }
}

void solve(int x)
{
    sum = 0;
    for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
    {
        gx = i - 1;
        int ctime = frames[x][i].time_frame;
        while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
            gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
        ll pans = 1LL * mu[x] * sum;
        if (ctime != -1)
        {
            ans[ctime] += pans;
            for (int tot = i; tot <= gx; tot++)
                undo();
        }
        else
        {
            for (int k = gx + 1; k < siz; k++)
                visit[frames[x][k].time_frame] = x;
            for (int k = 0; k <= q; k++)
                if (visit[k] != x)
                    ans[k] += pans;
        }
    }
    while (!stk.empty())
        undo();
}

// Reads the weighted tree and q weight updates, then for every time 0..q
// prints sum_d mu(d) * f_t(d) (see solve) computed offline.
int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    // Edges that are never updated hold at every time: insert once as static (-1).
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    // qframes[t]: (edge id, weight) for updated edges at time t; the node's
    // time_frame field temporarily carries the weight.
    for (int i = 1; i <= q; i++)
    {
        node u = node{qseg[i].x, qseg[i].y};
        // The new weight holds from time i until the same edge is updated again.
        for (int j = i; j <= q; j++)
        {
            if (j == i || qseg[i].x != qseg[j].x)
                qframes[j].push_back(u);
            else
                break;
        }
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        // First update of this edge: its original weight holds for times 0..i-1.
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    // ::size avoids the C++17 ambiguity with std::size under `using namespace std;`.
    for (int i = 1; i <= n; i++)
        ::size[i] = 1;
    for (int i = 1; i <= 1000000; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}

 

One Comment

Leave a Reply

Your email address will not be published. Required fields are marked *