A – Directory Traversal
没什么好说的,比较水的一个换根 DP。
const int MAX_N = 1e5 + 200;
int n, siz[MAX_N], head[MAX_N], current, wi[MAX_N], tot;
void addpath(int src, int dst)
edges[current].to = dst, edges[current].nxt = head[src];
for (int i = head[u]; i != -1; i = edges[i].nxt)
dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
void collect(int u, int fa)
g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
for (int i = head[u]; i != -1; i = edges[i].nxt)
freopen((src + ".in").c_str(), "r", stdin);
freopen((src + ".out").c_str(), "w", stdout);
memset(head, -1, sizeof(head));
for (int i = 1, m, val; i <= n; i++)
scanf("%s%d", name, &m), wi[i] = strlen(name);
siz[i] = 1, tag[i] = true, tot++;
scanf("%d", &val), addpath(i, val);
dfs(1, 0), collect(1, 0);
for (int i = 1; i <= n; i++)
ans = min(ans, dp[i] + g[i]);
// traversal.cpp
#include <bits/stdc++.h>
using namespace std;
// Array bound for nodes (nodes are 1-indexed up to n).
const int MAX_N = 1e5 + 200;
typedef long long ll;
// n: node count; siz[u]: number of files (nodes read with zero children) in u's subtree;
// head/current: forward-star list heads and the next free edge slot;
// wi[u]: length of node u's name; tot: total number of files.
int n, siz[MAX_N], head[MAX_N], current, wi[MAX_N], tot;
// dp[u]: summed path length from u to every file inside its subtree (filled by dfs);
// g[u]: summed path length from u to every file outside its subtree (filled by collect).
ll dp[MAX_N], g[MAX_N];
// Scratch buffer for one name; assumes names are shorter than 20 chars — TODO confirm limits.
char name[20];
// tag[u] is true when node u was read with zero children, i.e. it is a file.
bool tag[MAX_N];
// One directed edge of the forward-star adjacency list.
struct edge
{
int to, nxt;
} edges[MAX_N << 1];
void addpath(int src, int dst)
{
edges[current].to = dst, edges[current].nxt = head[src];
head[src] = current++;
}
void dfs(int u, int fa)
{
for (int i = head[u]; i != -1; i = edges[i].nxt)
{
dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
}
}
void collect(int u, int fa)
{
if (fa != 0)
g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
for (int i = head[u]; i != -1; i = edges[i].nxt)
collect(edges[i].to, u);
}
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    // Node i: its name, then m; m == 0 marks a file, otherwise m child ids follow.
    for (int i = 1, m, val; i <= n; i++)
    {
        scanf("%s%d", name, &m);
        wi[i] = strlen(name);
        if (m == 0)
        {
            siz[i] = 1, tag[i] = true, tot++;
        }
        else
        {
            while (m--)
                scanf("%d", &val), addpath(i, val);
        }
    }
    dfs(1, 0), collect(1, 0);
    // Only a directory may be the current directory. Evaluating a file would
    // subtract its own name from the parent's total and can undercut every real
    // directory (e.g. few files with long names), so files are skipped. Node 1
    // is always kept so a single-node input still yields an answer.
    ll ans = 2e18;
    for (int i = 1; i <= n; i++)
    {
        if (tag[i] && i != 1)
            continue;
        ans = min(ans, dp[i] + g[i]);
    }
    printf("%lld\n", ans);
    return 0;
}
// traversal.cpp
#include <bits/stdc++.h>
using namespace std;
// Array bound for nodes (nodes are 1-indexed up to n).
const int MAX_N = 1e5 + 200;
typedef long long ll;
// n: node count; siz[u]: number of files (nodes read with zero children) in u's subtree;
// head/current: forward-star list heads and the next free edge slot;
// wi[u]: length of node u's name; tot: total number of files.
int n, siz[MAX_N], head[MAX_N], current, wi[MAX_N], tot;
// dp[u]: summed path length from u to every file inside its subtree (filled by dfs);
// g[u]: summed path length from u to every file outside its subtree (filled by collect).
ll dp[MAX_N], g[MAX_N];
// Scratch buffer for one name; assumes names are shorter than 20 chars — TODO confirm limits.
char name[20];
// tag[u] is true when node u was read with zero children, i.e. it is a file.
bool tag[MAX_N];
// One directed edge of the forward-star adjacency list.
struct edge
{
int to, nxt;
} edges[MAX_N << 1];
void addpath(int src, int dst)
{
edges[current].to = dst, edges[current].nxt = head[src];
head[src] = current++;
}
void dfs(int u, int fa)
{
for (int i = head[u]; i != -1; i = edges[i].nxt)
{
dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
}
}
void collect(int u, int fa)
{
if (fa != 0)
g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
for (int i = head[u]; i != -1; i = edges[i].nxt)
collect(edges[i].to, u);
}
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    // Node i: its name, then m; m == 0 marks a file, otherwise m child ids follow.
    for (int i = 1, m, val; i <= n; i++)
    {
        scanf("%s%d", name, &m);
        wi[i] = strlen(name);
        if (m == 0)
        {
            siz[i] = 1, tag[i] = true, tot++;
        }
        else
        {
            while (m--)
                scanf("%d", &val), addpath(i, val);
        }
    }
    dfs(1, 0), collect(1, 0);
    // Only a directory may be the current directory. Evaluating a file would
    // subtract its own name from the parent's total and can undercut every real
    // directory (e.g. few files with long names), so files are skipped. Node 1
    // is always kept so a single-node input still yields an answer.
    ll ans = 2e18;
    for (int i = 1; i <= n; i++)
    {
        if (tag[i] && i != 1)
            continue;
        ans = min(ans, dp[i] + g[i]);
    }
    printf("%lld\n", ans);
    return 0;
}
B – 迫害 DJ
这道题其实在考场上能看得出来是找 Fibonacci 循环节,但是我不太会,所以打了个暴力走人。事后发现大家都不会。
首先,我们知道这道题本质是在求多个模域下的 \(g_x\)。其中 \(g_0 = a\)、\(g_1 = b\),\(g_x\) 可以被改写成 \(g_x = b \times Fib_{2x} - a \times Fib_{2(x - 1)}\)(取 \(Fib_0 = 0, Fib_1 = 1\))。所以问题被简化成求 Fibonacci 数列在若干个模域下的值。
求循环节比较人类智慧,设 \(mod = \prod p_i^{c_i}\),则:
\[ f(p) = \begin{cases} 3, & p = 2 \\ 8, & p = 3 \\ 20, & p = 5 \\ p - 1, & p \bmod 5 \in \{1, 4\} \\ 2p + 2, & \text{otherwise} \end{cases} \qquad \text{loop} = \operatorname{lcm}_i \{ f(p_i) \times p_i^{c_i - 1} \} \]
证明就不会了,有兴趣就去搜一搜吧。
知道这个循环节之后,考虑设计子任务 \(solve(n, k, mod)\),然后递归求即可。快速求 \(g_n\) 直接矩阵乘法即可。
int T, a, b, n, level, mod;
ll loops[MAX_N], tot, cmod;
freopen((src + ".in").c_str(), "r", stdin);
freopen((src + ".out").c_str(), "w", stdout);
ll *operator[](const int &id) { return mat[id]; }
void clear() { memset(mat, 0, sizeof(mat)); };
void epsilon() { mat[0][0] = mat[1][1] = 1; }
matrix operator*(const matrix &rhs)
for (int i = 0; i < 2; i++)
for (int k = 0; k < 2; k++)
for (int j = 0; j < 2; j++)
ret[i][j] = (0LL + ret[i][j] + 1LL * mat[i][k] * rhs.mat[k][j] % cmod) % cmod;
matrix operator^(const ll &rhs)
ret.clear(), ret.epsilon();
return (1LL * ((init * (trans ^ (2 * x - 1)))[0][1]) * b % cmod + cmod - 1LL * ((init * (trans ^ (2 * (x - 1) - 1)))[0][1]) * a % cmod) % cmod;
if (pbase % 5 == 1 || pbase % 5 == 4)
for (int i = 2; 1LL * i * i <= acc; i++)
cnt /= i, loops[tot] *= cnt;
loops[++tot] = getFn(acc);
for (int i = 2; i <= tot; i++)
loop = (loop / __gcd(loop, loops[i])) * loops[i];
ll solve(ll n_, ll k_, ll cm)
ll nxt_loop = getLoopLen(cm), now = solve(n_, k_ - 1, nxt_loop);
trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
init.clear(), init[0][0] = 0, init[0][1] = 1;
scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
printf("%lld\n", solve(n, level, mod));
// hakugai.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity of `loops` (one slot per distinct prime factor of a modulus).
const int MAX_N = 50;
// Per test: T test count; a = g_0, b = g_1; n start index;
// level: how many times g is applied (depth of solve's recursion); mod: output modulus.
int T, a, b, n, level, mod;
typedef long long ll;
// loops[1..tot]: per-prime-power period values from the latest getLoopLen call;
// cmod: modulus used by matrix multiplication (set by getGn).
ll loops[MAX_N], tot, cmod;
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
// 2x2 matrix whose arithmetic is modulo the global `cmod` (set by getGn).
struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    // Zero every entry.
    void clear() { memset(mat, 0, sizeof(mat)); };
    // Set the diagonal to 1; after clear() this yields the identity.
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    // Product modulo cmod. Entries are residues below cmod, and cmod can be a
    // nested period such as lcm(3, 2p + 2) for mod = 2p (several times 1e9),
    // so each entry product is formed in __int128 to avoid 64-bit overflow.
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                            ret[i][j] = (ret[i][j] + (ll)((__int128)mat[i][k] * rhs.mat[k][j] % cmod)) % cmod;
        return ret;
    }
    // this^rhs by binary exponentiation (rhs >= 0).
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;
// Returns g_x mod cm, where g_0 = a, g_1 = b and, for x >= 2,
// g_x = b * F(2x) - a * F(2x - 2) with F(0) = 0, F(1) = 1.
// Also sets the global `cmod` used by matrix multiplication.
ll getGn(ll x, ll cm)
{
    cmod = cm;
    // Reduce the seeds first: the base cases must be returned modulo cm too
    // (returning raw a / b could print a value >= mod), and reduced seeds keep
    // the products below bounded by cm^2.
    const ll ra = (a % cmod + cmod) % cmod;
    const ll rb = (b % cmod + cmod) % cmod;
    if (x == 0)
        return ra;
    if (x == 1)
        return rb;
    // main() sets init * trans^k = [F(k), F(k + 1)], so entry [0][1] is F(k + 1).
    const ll fibA = (init * (trans ^ (2 * x - 1)))[0][1]; // F(2x)
    const ll fibB = (init * (trans ^ (2 * x - 3)))[0][1]; // F(2x - 2)
    const ll plus = (ll)((__int128)fibA * rb % cmod);
    const ll minus = (ll)((__int128)fibB * ra % cmod);
    return (plus + cmod - minus) % cmod;
}
// Period value used by getLoopLen for the prime `pbase`:
// 3, 8, 20 for 2, 3, 5; pbase - 1 when pbase mod 5 is 1 or 4; else 2 * pbase + 2.
ll getFn(ll pbase)
{
    switch (pbase)
    {
    case 2:
        return 3;
    case 3:
        return 8;
    case 5:
        return 20;
    default:
        break;
    }
    const ll rem = pbase % 5;
    return (rem == 1 || rem == 4) ? pbase - 1 : 2 * pbase + 2;
}
// Returns the period used for modulus x (x >= 1): with x = prod p_i^c_i it is
// lcm_i(getFn(p_i) * p_i^(c_i - 1)). Uses the globals loops/tot as scratch.
// x == 1 has no prime factors and yields 1 (starting the lcm from loops[1]
// would read a stale or zero entry and later trigger a modulo by zero).
ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (ll p = 2; p * p <= acc; p++)
    {
        if (acc % p != 0)
            continue;
        ll pw = 1; // becomes p^(c - 1)
        acc /= p;
        while (acc % p == 0)
            acc /= p, pw *= p;
        loops[++tot] = getFn(p) * pw;
    }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    ll loop = 1;
    for (int i = 1; i <= tot; i++)
        loop = loop / __gcd(loop, loops[i]) * loops[i];
    return loop;
}
// Evaluates g applied k_ times to n_, modulo cm. The inner value is only needed
// modulo the period of g's index for modulus cm, so each nesting level recurses
// with that period as its modulus.
ll solve(ll n_, ll k_, ll cm)
{
    if (!k_)
        return n_ % cm;
    const ll period = getLoopLen(cm);
    const ll inner = solve(n_, k_ - 1, period);
    return getGn(inner, cm);
}
int main()
{
// fileIO("hakugai");
scanf("%d", &T);
trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
init.clear(), init[0][0] = 0, init[0][1] = 1;
while (T--)
{
scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
// get loop;
printf("%lld\n", solve(n, level, mod));
}
return 0;
}
// hakugai.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity of `loops` (one slot per distinct prime factor of a modulus).
const int MAX_N = 50;
// Per test: T test count; a = g_0, b = g_1; n start index;
// level: how many times g is applied (depth of solve's recursion); mod: output modulus.
int T, a, b, n, level, mod;
typedef long long ll;
// loops[1..tot]: per-prime-power period values from the latest getLoopLen call;
// cmod: modulus used by matrix multiplication (set by getGn).
ll loops[MAX_N], tot, cmod;
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
// 2x2 matrix whose arithmetic is modulo the global `cmod` (set by getGn).
struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    // Zero every entry.
    void clear() { memset(mat, 0, sizeof(mat)); };
    // Set the diagonal to 1; after clear() this yields the identity.
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    // Product modulo cmod. Entries are residues below cmod, and cmod can be a
    // nested period such as lcm(3, 2p + 2) for mod = 2p (several times 1e9),
    // so each entry product is formed in __int128 to avoid 64-bit overflow.
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                            ret[i][j] = (ret[i][j] + (ll)((__int128)mat[i][k] * rhs.mat[k][j] % cmod)) % cmod;
        return ret;
    }
    // this^rhs by binary exponentiation (rhs >= 0).
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;
// Returns g_x mod cm, where g_0 = a, g_1 = b and, for x >= 2,
// g_x = b * F(2x) - a * F(2x - 2) with F(0) = 0, F(1) = 1.
// Also sets the global `cmod` used by matrix multiplication.
ll getGn(ll x, ll cm)
{
    cmod = cm;
    // Reduce the seeds first: the base cases must be returned modulo cm too
    // (returning raw a / b could print a value >= mod), and reduced seeds keep
    // the products below bounded by cm^2.
    const ll ra = (a % cmod + cmod) % cmod;
    const ll rb = (b % cmod + cmod) % cmod;
    if (x == 0)
        return ra;
    if (x == 1)
        return rb;
    // main() sets init * trans^k = [F(k), F(k + 1)], so entry [0][1] is F(k + 1).
    const ll fibA = (init * (trans ^ (2 * x - 1)))[0][1]; // F(2x)
    const ll fibB = (init * (trans ^ (2 * x - 3)))[0][1]; // F(2x - 2)
    const ll plus = (ll)((__int128)fibA * rb % cmod);
    const ll minus = (ll)((__int128)fibB * ra % cmod);
    return (plus + cmod - minus) % cmod;
}
// Period value used by getLoopLen for the prime `pbase`:
// 3, 8, 20 for 2, 3, 5; pbase - 1 when pbase mod 5 is 1 or 4; else 2 * pbase + 2.
ll getFn(ll pbase)
{
    switch (pbase)
    {
    case 2:
        return 3;
    case 3:
        return 8;
    case 5:
        return 20;
    default:
        break;
    }
    const ll rem = pbase % 5;
    return (rem == 1 || rem == 4) ? pbase - 1 : 2 * pbase + 2;
}
// Returns the period used for modulus x (x >= 1): with x = prod p_i^c_i it is
// lcm_i(getFn(p_i) * p_i^(c_i - 1)). Uses the globals loops/tot as scratch.
// x == 1 has no prime factors and yields 1 (starting the lcm from loops[1]
// would read a stale or zero entry and later trigger a modulo by zero).
ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (ll p = 2; p * p <= acc; p++)
    {
        if (acc % p != 0)
            continue;
        ll pw = 1; // becomes p^(c - 1)
        acc /= p;
        while (acc % p == 0)
            acc /= p, pw *= p;
        loops[++tot] = getFn(p) * pw;
    }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    ll loop = 1;
    for (int i = 1; i <= tot; i++)
        loop = loop / __gcd(loop, loops[i]) * loops[i];
    return loop;
}
// Evaluates g applied k_ times to n_, modulo cm. The inner value is only needed
// modulo the period of g's index for modulus cm, so each nesting level recurses
// with that period as its modulus.
ll solve(ll n_, ll k_, ll cm)
{
    if (!k_)
        return n_ % cm;
    const ll period = getLoopLen(cm);
    const ll inner = solve(n_, k_ - 1, period);
    return getGn(inner, cm);
}
int main()
{
// fileIO("hakugai");
scanf("%d", &T);
trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
init.clear(), init[0][0] = 0, init[0][1] = 1;
while (T--)
{
scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
// get loop;
printf("%lld\n", solve(n, level, mod));
}
return 0;
}
C – 夕张的改造
这题就比较神仙了。
考虑矩阵树定理,直接做没法限制 \(k\) 这个东西。人类智慧告诉我们,如果我们把原树上的边的边权设为 \(1\)、把其余(未加入的)边的边权赋为 \(x\),那么最后用矩阵树定理求出来的东西就是一个关于 \(x\) 的多项式,其中 \(x^i\) 的系数就是恰好使用 \(i\) 条非树边的生成树个数,所以前 \(k + 1\) 项系数之和就是答案。
我们让 \(x = 1, 2, \dots, n\) 求出 \(n\) 个点值(多项式次数不超过 \(n - 1\),所以 \(n\) 个点值足够),然后再用高斯消元插值得到多项式,最后把前 \(k + 1\) 项系数加起来即可。这个题我个人很喜欢。
const int MAX_N = 55, mod = 998244353;
int n, limit, mat[MAX_N][MAX_N], poly_val[MAX_N], poly_coeff[MAX_N];
freopen((src + ".in").c_str(), "r", stdin);
freopen((src + ".out").c_str(), "w", stdout);
int fpow(int bas, int tim)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
for (int i = 0; i < n - 1; i++)
for (int j = i; j < n - 1; j++)
for (int j = i; j < n - 1; j++)
swap(mat[i][j], mat[key][j]);
int inv = fpow(mat[i][i], mod - 2);
for (int j = i + 1; j < n - 1; j++)
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = i; k < n - 1; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
for (int i = 0; i < n - 1; i++)
res = 1LL * res * mat[i][i] % mod;
scanf("%d%d", &n, &limit);
for (int i = 1, fa; i < n; i++)
scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
for (int x = 1; x <= n; x++)
memset(mat, 0, sizeof(mat));
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
mat[i][j] = mod - x, mat[i][i] += x;
mat[i][j] = mod - 1, mat[i][i]++;
memset(mat, 0, sizeof(mat));
for (int i = 0; i < n; i++)
for (int j = 1; j < n; j++)
mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
mat[i][n] = poly_val[i + 1];
for (int i = 0; i < n; i++)
for (int j = i; j < n; j++)
for (int j = i; j <= n; j++)
swap(mat[i][j], mat[key][j]);
int inv = fpow(mat[i][i], mod - 2);
for (int j = 0; j < n; j++)
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = i + 1; k <= n; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
for (int i = 0; i < n; i++)
poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
for (int i = 0; i <= limit; i++)
ans = (0LL + ans + poly_coeff[i]) % mod;
// kaisou.cpp
#include <bits/stdc++.h>
using namespace std;
// MAX_N bounds n (plus the augmented column of the interpolation system);
// all arithmetic is modulo the prime `mod`, inverses via fpow(v, mod - 2).
const int MAX_N = 55, mod = 998244353;
// n: node count; limit: max number of non-tree edges a counted spanning tree may use;
// mat: working matrix for both the determinant and the interpolation;
// poly_val[x]: spanning-tree polynomial evaluated at x = 1..n;
// poly_coeff[i]: its coefficient of x^i.
int n, limit, mat[MAX_N][MAX_N], poly_val[MAX_N], poly_coeff[MAX_N];
// mp[i][j]: (i, j) is an edge of the input tree.
bool mp[MAX_N][MAX_N];
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
// Computes bas^tim modulo `mod` by square-and-multiply.
int fpow(int bas, int tim)
{
    long long result = 1, base = bas;
    for (; tim; tim >>= 1)
    {
        if (tim & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return (int)result;
}
// Determinant, modulo `mod`, of the top-left (n-1) x (n-1) block of `mat`.
// main() fills `mat` with a weighted Laplacian before calling this, so by the
// matrix-tree theorem the result is the weighted spanning-tree count.
// Gaussian elimination to upper-triangular form; overwrites `mat`.
int gauss()
{
int res = 1;
for (int i = 0; i < n - 1; i++)
{
// Pivot search: first row at or below i with a nonzero entry in column i.
int key = i;
for (int j = i; j < n - 1; j++)
if (mat[j][i] != 0)
{
key = j;
break;
}
if (key != i)
{
// Swapping two rows negates the determinant. Only columns i.. are
// swapped; entries left of i in these rows were already cleared.
res = mod - res;
for (int j = i; j < n - 1; j++)
swap(mat[i][j], mat[key][j]);
}
// Without a pivot mat[i][i] is 0, so inv is 0, no row changes, and the
// diagonal product below (the determinant) becomes 0.
int inv = fpow(mat[i][i], mod - 2);
for (int j = i + 1; j < n - 1; j++)
{
// Subtract rate * (pivot row) to clear column i of row j.
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = i; k < n - 1; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
}
}
// Determinant = sign * product of the diagonal of the triangular matrix.
for (int i = 0; i < n - 1; i++)
res = 1LL * res * mat[i][i] % mod;
return res;
}
int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    // Tree edges: node i (1..n-1) is joined to fa.
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    // Laplacian of the complete graph where tree edges weigh 1 and every other
    // edge weighs x. Its minor is P(x) = sum over spanning trees of
    // x^(number of non-tree edges used); deg P <= n - 1, so x = 1..n suffice.
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                const int w = mp[i][j] ? 1 : x;
                mat[i][j] = mod - w;
                mat[i][i] += w;
            }
        poly_val[x] = gauss();
    }
    // Recover the coefficients: solve sum_j c_j * (i + 1)^j = poly_val[i + 1]
    // by Gauss-Jordan elimination on the augmented n x (n + 1) matrix.
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        // Column i of the other rows is left stale: it is never read again.
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    // Coefficients above x^(n-1) are zero, so clamping does not change the sum
    // and keeps the index inside poly_coeff even when limit >= MAX_N.
    const int upto = min(limit, n - 1);
    int ans = 0;
    for (int i = 0; i <= upto; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}
// kaisou.cpp
#include <bits/stdc++.h>
using namespace std;
// MAX_N bounds n (plus the augmented column of the interpolation system);
// all arithmetic is modulo the prime `mod`, inverses via fpow(v, mod - 2).
const int MAX_N = 55, mod = 998244353;
// n: node count; limit: max number of non-tree edges a counted spanning tree may use;
// mat: working matrix for both the determinant and the interpolation;
// poly_val[x]: spanning-tree polynomial evaluated at x = 1..n;
// poly_coeff[i]: its coefficient of x^i.
int n, limit, mat[MAX_N][MAX_N], poly_val[MAX_N], poly_coeff[MAX_N];
// mp[i][j]: (i, j) is an edge of the input tree.
bool mp[MAX_N][MAX_N];
// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    const string inPath = src + ".in";
    const string outPath = src + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
// Computes bas^tim modulo `mod` by square-and-multiply.
int fpow(int bas, int tim)
{
    long long result = 1, base = bas;
    for (; tim; tim >>= 1)
    {
        if (tim & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return (int)result;
}
// Determinant, modulo `mod`, of the top-left (n-1) x (n-1) block of `mat`.
// main() fills `mat` with a weighted Laplacian before calling this, so by the
// matrix-tree theorem the result is the weighted spanning-tree count.
// Gaussian elimination to upper-triangular form; overwrites `mat`.
int gauss()
{
int res = 1;
for (int i = 0; i < n - 1; i++)
{
// Pivot search: first row at or below i with a nonzero entry in column i.
int key = i;
for (int j = i; j < n - 1; j++)
if (mat[j][i] != 0)
{
key = j;
break;
}
if (key != i)
{
// Swapping two rows negates the determinant. Only columns i.. are
// swapped; entries left of i in these rows were already cleared.
res = mod - res;
for (int j = i; j < n - 1; j++)
swap(mat[i][j], mat[key][j]);
}
// Without a pivot mat[i][i] is 0, so inv is 0, no row changes, and the
// diagonal product below (the determinant) becomes 0.
int inv = fpow(mat[i][i], mod - 2);
for (int j = i + 1; j < n - 1; j++)
{
// Subtract rate * (pivot row) to clear column i of row j.
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = i; k < n - 1; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
}
}
// Determinant = sign * product of the diagonal of the triangular matrix.
for (int i = 0; i < n - 1; i++)
res = 1LL * res * mat[i][i] % mod;
return res;
}
int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    // Tree edges: node i (1..n-1) is joined to fa.
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    // Laplacian of the complete graph where tree edges weigh 1 and every other
    // edge weighs x. Its minor is P(x) = sum over spanning trees of
    // x^(number of non-tree edges used); deg P <= n - 1, so x = 1..n suffice.
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                if (i == j)
                    continue;
                const int w = mp[i][j] ? 1 : x;
                mat[i][j] = mod - w;
                mat[i][i] += w;
            }
        poly_val[x] = gauss();
    }
    // Recover the coefficients: solve sum_j c_j * (i + 1)^j = poly_val[i + 1]
    // by Gauss-Jordan elimination on the augmented n x (n + 1) matrix.
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        // Column i of the other rows is left stale: it is never read again.
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    // Coefficients above x^(n-1) are zero, so clamping does not change the sum
    // and keeps the index inside poly_coeff even when limit >= MAX_N.
    const int upto = min(limit, n - 1);
    int ans = 0;
    for (int i = 0; i <= upto; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}
D – 亚特兰大
这题比赛的时候用了一个比较猥琐的方法,过掉了数据不强的 70 分。正解还是很有意思的。
我们考虑把一条边按其边权的每个约数都拆成一条独立的边,然后枚举每一个出现过的约数 \(d\),记只使用边权为 \(d\) 的倍数的边时连通的点对个数为 \(f(d)\),那么答案就是 \(\sum_{d \ge 1} \mu(d) f(d)\)。这一个步骤非常的 nb。
有了这个之后我们就可以来算了。考虑到 \(q\) 很小,我们离线下来之后也分解出来,不过要标上时间。
我们先处理那些没被修改过的边,用并查集合并。处理完这些之后,我们再来处理询问涉及的边:这个时候需要可撤销的并查集(按大小启发式合并、不做路径压缩),每个时刻合并完统计后再撤销。最后累加到答案里面去。
我当天下午想了半个小时具体实现之后还是写不动,最后还是去看题解的代码了(wtcl)。
// #pragma GCC optimize(2)
const int MAX_N = 1e6 + 200;
int n, q, mu[MAX_N], primes[MAX_N], tot, mem[MAX_N], size[MAX_N], last_time[MAX_N], visit[MAX_N];
vector<node> frames[MAX_N], qframes[MAX_N];
vector<int> facts[MAX_N];
} org[MAX_N], qseg[MAX_N];
freopen((src + ".in").c_str(), "r", stdin);
freopen((src + ".out").c_str(), "w", stdout);
for (int i = 2; i < MAX_N; i++)
primes[++tot] = i, mu[i] = -1;
for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
vis[i * primes[j]] = true, mu[i * primes[j]] = -mu[i];
for (int i = 2; i < MAX_N; i++)
for (int j = i; j <= 1e6; j += i)
int fx = find(x), fy = find(y);
mem[fy] = fx, sum += 1LL * size[fx] * size[fy];
size[fx] += size[fy], stk.push(fy);
int fy = stk.top(), fx = mem[fy];
size[fx] -= size[fy], sum -= 1LL * size[fx] * size[fy], mem[fy] = 0;
void insert(int w, int id, int time_frame)
node u = node{id, time_frame};
for (int stat = 0; stat < (1 << m); stat++)
for (int i = 0; i < m; i++)
frames[pans].push_back(u);
for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
int ctime = frames[x][i].time_frame;
while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
ll pans = 1LL * mu[x] * sum;
for (int tot = i; tot <= gx; tot++)
for (int k = gx + 1; k < siz; k++)
visit[frames[x][k].time_frame] = x;
for (int k = 0; k <= q; k++)
scanf("%d", &n), sieve();
for (int i = 1; i <= n - 1; i++)
scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
for (int i = 1; i <= q; i++)
scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
for (int i = 1; i < n; i++)
for (int i = 1; i <= q; i++)
node u = node{qseg[i].x, qseg[i].y};
for (int j = i; j <= q; j++)
if (j == i || qseg[i].x != qseg[j].x)
for (int j = 1; j < i; j++)
if (qseg[i].x == qseg[j].x)
node v = node{qseg[i].x, org[qseg[i].x].z};
for (int j = 0; j < i; j++)
for (int i = 0; i <= q; i++)
for (int j = 0, siz = qframes[i].size(); j < siz; j++)
insert(qframes[i][j].time_frame, qframes[i][j].id, i);
for (int i = 1; i <= n; i++)
for (int i = 1; i <= 1e6; i++)
for (int i = 0; i <= q; i++)
printf("%lld\n", ans[i]);
// atoranta.cpp
// #pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1e6 + 200;
typedef long long ll;
// The DSU size array and the per-divisor time marker are named comp_size and
// seen_at rather than size / visit: with `using namespace std;` under C++17,
// unqualified `size` / `visit` also find std::size / std::visit, and every use
// becomes ambiguous (compile error).
//   mu[]          Moebius function (linear sieve)
//   primes / tot  primes found by the sieve
//   mem[v]        DSU parent of v (0 means v is a root)
//   comp_size[r]  size of the component rooted at r
//   last_time[e]  nonzero iff edge e is modified by some query
//   seen_at[t]    equals x iff time t has its own group in frames[x]
int n, q, mu[MAX_N], primes[MAX_N], tot, mem[MAX_N], comp_size[MAX_N], last_time[MAX_N], seen_at[MAX_N];
// ans[t]: answer after the first t queries (t = 0 is the initial tree);
// sum: number of connected vertex pairs in the current DSU state.
ll ans[MAX_N], sum;
// vis[v]: v is composite (sieve marker).
bool vis[MAX_N];
// Roots absorbed by merge(), most recent on top, so undo() can roll back.
stack<int> stk;
struct node
{
    int id, time_frame;
};
// frames[d]: (edge id, time) for each edge state whose weight is divisible by
// the square-free d; time -1 means the edge is never modified.
// qframes[t]: (edge id, weight) of the query-touched edges at time t; here the
// weight is stored in the time_frame field.
vector<node> frames[MAX_N], qframes[MAX_N];
// facts[v]: distinct prime factors of v, for v <= 1e6.
vector<int> facts[MAX_N];
// Tree edge (x, y) with weight z; for queries, x = edge index and y = new weight.
struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// Linear sieve for mu[], then lists the distinct prime factors of every v <= 1e6.
void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
            primes[++tot] = i, mu[i] = -1;
        for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
        {
            vis[i * primes[j]] = true, mu[i * primes[j]] = -mu[i];
            if (i % primes[j] == 0)
            {
                // primes[j]^2 divides i * primes[j].
                mu[i * primes[j]] = 0;
                break;
            }
        }
    }
    for (int i = 2; i < MAX_N; i++)
        if (!vis[i])
            for (int j = i; j <= 1e6; j += i)
                facts[j].push_back(i);
}

// DSU root of x. No path compression, so merges can be undone.
int find(int x)
{
    while (mem[x])
        x = mem[x];
    return x;
}

// Union by size. Adds the newly connected pairs to `sum` and records the
// absorbed root for undo(). Assumes x and y are in different components (each
// tree edge appears at most once per DSU state); merging inside one component
// would set mem[root] = root.
void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    if (comp_size[fx] < comp_size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * comp_size[fx] * comp_size[fy];
    comp_size[fx] += comp_size[fy], stk.push(fy);
}

// Reverts the most recent merge().
void undo()
{
    int fy = stk.top(), fx = mem[fy];
    comp_size[fx] -= comp_size[fy], sum -= 1LL * comp_size[fx] * comp_size[fy], mem[fy] = 0;
    stk.pop();
}

// Files edge `id` (weight w, at time `time_frame`) under every square-free
// divisor of w, i.e. every product of a subset of w's distinct primes.
void insert(int w, int id, int time_frame)
{
    node u = node{id, time_frame};
    int m = facts[w].size();
    for (int stat = 0; stat < (1 << m); stat++)
    {
        int pans = 1;
        for (int i = 0; i < m; i++)
            if (stat & (1 << i))
                pans *= facts[w][i];
        frames[pans].push_back(u);
    }
}

// Adds mu(x) * (connected pairs using edges whose weight is divisible by x)
// to ans[t] for every time t. frames[x] is grouped by time: the static group
// (-1) first, then times 0..q in order.
void solve(int x)
{
    sum = 0;
    for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
    {
        gx = i - 1;
        int ctime = frames[x][i].time_frame;
        while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
            gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
        ll pans = 1LL * mu[x] * sum;
        if (ctime != -1)
        {
            // Time-specific edges: record the answer, then roll them back.
            ans[ctime] += pans;
            for (int cnt = i; cnt <= gx; cnt++)
                undo();
        }
        else
        {
            // Static edges stay merged. Times owning a later group are credited
            // when that group is processed; every other time gets pans now.
            for (int k = gx + 1; k < siz; k++)
                seen_at[frames[x][k].time_frame] = x;
            for (int k = 0; k <= q; k++)
                if (seen_at[k] != x)
                    ans[k] += pans;
        }
    }
    while (!stk.empty())
        undo();
}

int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    // Edges no query touches keep their weight at every time: time -1.
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    for (int i = 1; i <= q; i++)
    {
        // Query i's weight holds from time i until the same edge changes again.
        node u = node{qseg[i].x, qseg[i].y};
        for (int j = i; j <= q; j++)
        {
            if (j != i && qseg[i].x == qseg[j].x)
                break;
            qframes[j].push_back(u);
        }
        // On the first query touching this edge, its original weight holds for
        // times 0 .. i-1.
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    // Inserted in time order, so each frames[d] holds the static group first
    // and then contiguous groups for times 0..q, as solve() requires.
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    for (int i = 1; i <= n; i++)
        comp_size[i] = 1;
    for (int i = 1; i <= 1e6; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
// atoranta.cpp
// #pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1e6 + 200;
typedef long long ll;
// The DSU size array and the per-divisor time marker are named comp_size and
// seen_at rather than size / visit: with `using namespace std;` under C++17,
// unqualified `size` / `visit` also find std::size / std::visit, and every use
// becomes ambiguous (compile error).
//   mu[]          Moebius function (linear sieve)
//   primes / tot  primes found by the sieve
//   mem[v]        DSU parent of v (0 means v is a root)
//   comp_size[r]  size of the component rooted at r
//   last_time[e]  nonzero iff edge e is modified by some query
//   seen_at[t]    equals x iff time t has its own group in frames[x]
int n, q, mu[MAX_N], primes[MAX_N], tot, mem[MAX_N], comp_size[MAX_N], last_time[MAX_N], seen_at[MAX_N];
// ans[t]: answer after the first t queries (t = 0 is the initial tree);
// sum: number of connected vertex pairs in the current DSU state.
ll ans[MAX_N], sum;
// vis[v]: v is composite (sieve marker).
bool vis[MAX_N];
// Roots absorbed by merge(), most recent on top, so undo() can roll back.
stack<int> stk;
struct node
{
    int id, time_frame;
};
// frames[d]: (edge id, time) for each edge state whose weight is divisible by
// the square-free d; time -1 means the edge is never modified.
// qframes[t]: (edge id, weight) of the query-touched edges at time t; here the
// weight is stored in the time_frame field.
vector<node> frames[MAX_N], qframes[MAX_N];
// facts[v]: distinct prime factors of v, for v <= 1e6.
vector<int> facts[MAX_N];
// Tree edge (x, y) with weight z; for queries, x = edge index and y = new weight.
struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];

// Redirects stdin/stdout to "<src>.in" / "<src>.out".
void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

// Linear sieve for mu[], then lists the distinct prime factors of every v <= 1e6.
void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
            primes[++tot] = i, mu[i] = -1;
        for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
        {
            vis[i * primes[j]] = true, mu[i * primes[j]] = -mu[i];
            if (i % primes[j] == 0)
            {
                // primes[j]^2 divides i * primes[j].
                mu[i * primes[j]] = 0;
                break;
            }
        }
    }
    for (int i = 2; i < MAX_N; i++)
        if (!vis[i])
            for (int j = i; j <= 1e6; j += i)
                facts[j].push_back(i);
}

// DSU root of x. No path compression, so merges can be undone.
int find(int x)
{
    while (mem[x])
        x = mem[x];
    return x;
}

// Union by size. Adds the newly connected pairs to `sum` and records the
// absorbed root for undo(). Assumes x and y are in different components (each
// tree edge appears at most once per DSU state); merging inside one component
// would set mem[root] = root.
void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    if (comp_size[fx] < comp_size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * comp_size[fx] * comp_size[fy];
    comp_size[fx] += comp_size[fy], stk.push(fy);
}

// Reverts the most recent merge().
void undo()
{
    int fy = stk.top(), fx = mem[fy];
    comp_size[fx] -= comp_size[fy], sum -= 1LL * comp_size[fx] * comp_size[fy], mem[fy] = 0;
    stk.pop();
}

// Files edge `id` (weight w, at time `time_frame`) under every square-free
// divisor of w, i.e. every product of a subset of w's distinct primes.
void insert(int w, int id, int time_frame)
{
    node u = node{id, time_frame};
    int m = facts[w].size();
    for (int stat = 0; stat < (1 << m); stat++)
    {
        int pans = 1;
        for (int i = 0; i < m; i++)
            if (stat & (1 << i))
                pans *= facts[w][i];
        frames[pans].push_back(u);
    }
}

// Adds mu(x) * (connected pairs using edges whose weight is divisible by x)
// to ans[t] for every time t. frames[x] is grouped by time: the static group
// (-1) first, then times 0..q in order.
void solve(int x)
{
    sum = 0;
    for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
    {
        gx = i - 1;
        int ctime = frames[x][i].time_frame;
        while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
            gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
        ll pans = 1LL * mu[x] * sum;
        if (ctime != -1)
        {
            // Time-specific edges: record the answer, then roll them back.
            ans[ctime] += pans;
            for (int cnt = i; cnt <= gx; cnt++)
                undo();
        }
        else
        {
            // Static edges stay merged. Times owning a later group are credited
            // when that group is processed; every other time gets pans now.
            for (int k = gx + 1; k < siz; k++)
                seen_at[frames[x][k].time_frame] = x;
            for (int k = 0; k <= q; k++)
                if (seen_at[k] != x)
                    ans[k] += pans;
        }
    }
    while (!stk.empty())
        undo();
}

int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    // Edges no query touches keep their weight at every time: time -1.
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    for (int i = 1; i <= q; i++)
    {
        // Query i's weight holds from time i until the same edge changes again.
        node u = node{qseg[i].x, qseg[i].y};
        for (int j = i; j <= q; j++)
        {
            if (j != i && qseg[i].x == qseg[j].x)
                break;
            qframes[j].push_back(u);
        }
        // On the first query touching this edge, its original weight holds for
        // times 0 .. i-1.
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    // Inserted in time order, so each frames[d] holds the static group first
    // and then contiguous groups for times 0..q, as solve() requires.
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    for (int i = 1; i <= n; i++)
        comp_size[i] = 1;
    for (int i = 1; i <= 1e6; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
[…] 具体见:https://kalorona.com/oi/fortuna-oj-pa-apr-19/ […]