「Fortuna OJ」Apr 19 省选 A 组 – 解题报告

A – Directory Traversal

没什么好说的,比较水的一个换根 DP。

// traversal.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1e5 + 200;

typedef long long ll;

int n, siz[MAX_N], head[MAX_N], current, wi[MAX_N], tot;
ll dp[MAX_N], g[MAX_N];
char name[20];
bool tag[MAX_N];

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

void dfs(int u, int fa)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        dfs(edges[i].to, u), siz[u] += siz[edges[i].to];
        dp[u] += dp[edges[i].to] + 1LL * siz[edges[i].to] * wi[edges[i].to] + siz[edges[i].to] - (tag[edges[i].to] == true);
    }
}

void collect(int u, int fa)
{
    if (fa != 0)
        g[u] = g[fa] + dp[fa] - dp[u] - 1LL * siz[u] * wi[u] - (siz[u] - (tag[u] == true)) + ((tot - siz[u]) * 3LL);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        collect(edges[i].to, u);
}

void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

int main()
{
    fileIO("traversal");
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1, m, val; i <= n; i++)
    {
        scanf("%s%d", name, &m), wi[i] = strlen(name);
        if (m == 0)
            siz[i] = 1, tag[i] = true, tot++;
        else
            while (m--)
                scanf("%d", &val), addpath(i, val);
    }
    dfs(1, 0), collect(1, 0);
    ll ans = 2e18;
    for (int i = 1; i <= n; i++)
        ans = min(ans, dp[i] + g[i]);
    printf("%lld\n", ans);
    return 0;
}

B – 迫害 DJ

这道题其实在考场上能看得出来是找 Fibonacci 循环节,但是我不太会,所以打了个暴力走人。事后发现大家都不会。

首先,我们知道这道题本质是在求多个模域下的 \(g_x\)。\(g_x\) 可以被改写成 \(g_x = b \times Fib_{2n} – a \times Fib_{2(n – 1)}\)。所以问题被简化成求 Fibonacci 数列在若干个模域下的值。

求循环节比较人类智慧,设 \(mod = \prod p_i^{c_i}\),则:

\[ f(p) = \begin{cases} 3, p = 2 \\ 8, p = 3 \\ 5, p = 20 \\ p – 1, p \bmod 5 = 1, 4 \\ 2p + 2, \text{otherwise} \end{cases} \\ loop = \text{lcm} \{ f(p_i) \times p_i^{c_i – 1} \} \]

证明就不会了,有兴趣就去搜一搜吧。

知道这个循环节之后,考虑设计子任务 \(solve(n, k, mod)\),然后递归求即可。快速求 \(g_n\) 直接矩阵乘法即可。

// hakugai.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 50;

int T, a, b, n, level, mod;

typedef long long ll;

ll loops[MAX_N], tot, cmod;

void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

struct matrix
{
    ll mat[2][2];
    ll *operator[](const int &id) { return mat[id]; }
    void clear() { memset(mat, 0, sizeof(mat)); };
    void epsilon() { mat[0][0] = mat[1][1] = 1; }
    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int k = 0; k < 2; k++)
                if (mat[i][k])
                    for (int j = 0; j < 2; j++)
                        if (rhs.mat[k][j])
                            ret[i][j] = (0LL + ret[i][j] + 1LL * mat[i][k] * rhs.mat[k][j] % cmod) % cmod;
        return ret;
    }
    matrix operator^(const ll &rhs)
    {
        matrix ret, bas = *this;
        ret.clear(), ret.epsilon();
        ll tim = rhs;
        while (tim)
        {
            if (tim & 1LL)
                ret = ret * bas;
            bas = bas * bas;
            tim >>= 1;
        }
        return ret;
    }
} init, trans;

ll getGn(ll x, ll cm)
{
    cmod = cm;
    if (x == 0)
        return a;
    if (x == 1)
        return b;
    return (1LL * ((init * (trans ^ (2 * x - 1)))[0][1]) * b % cmod + cmod - 1LL * ((init * (trans ^ (2 * (x - 1) - 1)))[0][1]) * a % cmod) % cmod;
}

ll getFn(ll pbase)
{
    if (pbase == 2)
        return 3;
    if (pbase == 3)
        return 8;
    if (pbase == 5)
        return 20;
    if (pbase % 5 == 1 || pbase % 5 == 4)
        return pbase - 1;
    return 2 * pbase + 2;
}

ll getLoopLen(ll x)
{
    tot = 0;
    ll acc = x;
    for (int i = 2; 1LL * i * i <= acc; i++)
        if (acc % i == 0)
        {
            loops[++tot] = getFn(i);
            ll cnt = 1;
            while (acc % i == 0)
                acc /= i, cnt *= i;
            cnt /= i, loops[tot] *= cnt;
        }
    if (acc != 1)
        loops[++tot] = getFn(acc);
    ll loop = loops[1];
    for (int i = 2; i <= tot; i++)
        loop = (loop / __gcd(loop, loops[i])) * loops[i];
    return loop;
}

ll solve(ll n_, ll k_, ll cm)
{
    if (k_ == 0)
        return n_ % cm;
    ll nxt_loop = getLoopLen(cm), now = solve(n_, k_ - 1, nxt_loop);
    ll res = getGn(now, cm);
    return res;
}

int main()
{
    // fileIO("hakugai");
    scanf("%d", &T);
    trans.clear(), trans[1][0] = 1, trans[1][1] = 1, trans[0][1] = 1;
    init.clear(), init[0][0] = 0, init[0][1] = 1;
    while (T--)
    {
        scanf("%d%d%d%d%d", &a, &b, &n, &level, &mod);
        // get loop;
        printf("%lld\n", solve(n, level, mod));
    }
    return 0;
}

B – 夕张的改造

这题就比较神仙了。

考虑矩阵树定理,直接做没法限制 \(k\) 这个东西。人类智慧告诉我们,如果我们把未加入的边的边权赋为 \(x\),那么最后我们矩阵树求出来的东西肯定就是一个关于 \(x\) 的多项式,那么前 \(k + 1\) 项的系数就是各自的方案数。

我们让 \(x = 1, 2, \dots, n\) 求出点值,然后再用高斯消元差值出来得到多项式,再加前 \(k + 1\) 项系数即可。这个题我个人很喜欢。

// kaisou.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 55, mod = 998244353;

int n, limit, mat[MAX_N][MAX_N], poly_val[MAX_N], poly_coeff[MAX_N];
bool mp[MAX_N][MAX_N];

void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

int gauss()
{
    int res = 1;
    for (int i = 0; i < n - 1; i++)
    {
        int key = i;
        for (int j = i; j < n - 1; j++)
            if (mat[j][i] != 0)
            {
                key = j;
                break;
            }
        if (key != i)
        {
            res = mod - res;
            for (int j = i; j < n - 1; j++)
                swap(mat[i][j], mat[key][j]);
        }
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = i + 1; j < n - 1; j++)
        {
            int rate = 1LL * mat[j][i] * inv % mod;
            for (int k = i; k < n - 1; k++)
                mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
        }
    }
    for (int i = 0; i < n - 1; i++)
        res = 1LL * res * mat[i][i] % mod;
    return res;
}

int main()
{
    // fileIO("kaisou");
    scanf("%d%d", &n, &limit);
    for (int i = 1, fa; i < n; i++)
        scanf("%d", &fa), mp[i][fa] = mp[fa][i] = true;
    for (int x = 1; x <= n; x++)
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                if (i != j)
                    if (!mp[i][j])
                        mat[i][j] = mod - x, mat[i][i] += x;
                    else
                        mat[i][j] = mod - 1, mat[i][i]++;
        poly_val[x] = gauss();
    }
    memset(mat, 0, sizeof(mat));
    for (int i = 0; i < n; i++)
    {
        mat[i][0] = 1;
        for (int j = 1; j < n; j++)
            mat[i][j] = 1LL * mat[i][j - 1] * (i + 1) % mod;
        mat[i][n] = poly_val[i + 1];
    }
    for (int i = 0; i < n; i++)
    {
        int key = i;
        for (int j = i; j < n; j++)
            if (mat[j][i] > 0)
            {
                key = j;
                break;
            }
        if (key != i)
            for (int j = i; j <= n; j++)
                swap(mat[i][j], mat[key][j]);
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = 0; j < n; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = i + 1; k <= n; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 0; i < n; i++)
        poly_coeff[i] = 1LL * mat[i][n] * fpow(mat[i][i], mod - 2) % mod;
    int ans = 0;
    for (int i = 0; i <= limit; i++)
        ans = (0LL + ans + poly_coeff[i]) % mod;
    printf("%d\n", ans);
    return 0;
}

D – 亚特兰大

这题比赛的时候用了一个比较猥琐的方法,过掉了数据不强的 70 分。正解还是很有意思的。

我们考虑把一条边的每个约数都拆成一条独立的边,然后我们枚举每一个出现过的约数 \(d\),记连通点对的个数为 \(f(d)\),那么答案就是 \(\sum_{d = 1} \mu(d) f(d)\)。这一个步骤非常的 nb。

有了这个之后我们就可以来算了。考虑到 \(q\) 很小,我们离线下来之后也分解出来,不过要标上时间。

我们先处理那些没动过的边,用并查集做。动完这些之后,我们就需要来处理询问的边。我们这个时候需要可撤销的并查集,用启发式合并来实现。最后累加到答案里面去。

我当天下午想了半个小时具体实现之后还是写不动,最后还是去看题解的代码了(wtcl)。

// atoranta.cpp
// #pragma GCC optimize(2)
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1e6 + 200;

typedef long long ll;

int n, q, mu[MAX_N], primes[MAX_N], tot, mem[MAX_N], size[MAX_N], last_time[MAX_N], visit[MAX_N];
ll ans[MAX_N], sum;
bool vis[MAX_N];
stack<int> stk;

struct node
{
    int id, time_frame;
};
vector<node> frames[MAX_N], qframes[MAX_N];
vector<int> facts[MAX_N];

struct segment
{
    int x, y, z;
} org[MAX_N], qseg[MAX_N];

void fileIO(string src)
{
    freopen((src + ".in").c_str(), "r", stdin);
    freopen((src + ".out").c_str(), "w", stdout);
}

void sieve()
{
    mu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!vis[i])
            primes[++tot] = i, mu[i] = -1;
        for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
        {
            vis[i * primes[j]] = true, mu[i * primes[j]] = -mu[i];
            if (i % primes[j] == 0)
            {
                mu[i * primes[j]] = 0;
                break;
            }
        }
    }
    for (int i = 2; i < MAX_N; i++)
        if (!vis[i])
            for (int j = i; j <= 1e6; j += i)
                facts[j].push_back(i);
}

int find(int x)
{
    while (mem[x])
        x = mem[x];
    return x;
}

void merge(int x, int y)
{
    int fx = find(x), fy = find(y);
    if (size[fx] < size[fy])
        swap(fx, fy);
    mem[fy] = fx, sum += 1LL * size[fx] * size[fy];
    size[fx] += size[fy], stk.push(fy);
}

void undo()
{
    int fy = stk.top(), fx = mem[fy];
    size[fx] -= size[fy], sum -= 1LL * size[fx] * size[fy], mem[fy] = 0;
    stk.pop();
}

void insert(int w, int id, int time_frame)
{
    node u = node{id, time_frame};
    int m = facts[w].size();
    for (int stat = 0; stat < (1 << m); stat++)
    {
        int pans = 1;
        for (int i = 0; i < m; i++)
            if (stat & (1 << i))
                pans *= facts[w][i];
        frames[pans].push_back(u);
    }
}

void solve(int x)
{
    sum = 0;
    for (int i = 0, gx, siz = frames[x].size(); i < siz; i = gx + 1)
    {
        gx = i - 1;
        int ctime = frames[x][i].time_frame;
        while (gx < siz - 1 && frames[x][gx + 1].time_frame == ctime)
            gx++, merge(org[frames[x][gx].id].x, org[frames[x][gx].id].y);
        ll pans = 1LL * mu[x] * sum;
        if (ctime != -1)
        {
            ans[ctime] += pans;
            for (int tot = i; tot <= gx; tot++)
                undo();
        }
        else
        {
            for (int k = gx + 1; k < siz; k++)
                visit[frames[x][k].time_frame] = x;
            for (int k = 0; k <= q; k++)
                if (visit[k] != x)
                    ans[k] += pans;
        }
    }
    while (!stk.empty())
        undo();
}

int main()
{
    fileIO("atoranta");
    scanf("%d", &n), sieve();
    for (int i = 1; i <= n - 1; i++)
        scanf("%d%d%d", &org[i].x, &org[i].y, &org[i].z);
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &qseg[i].x, &qseg[i].y), last_time[qseg[i].x] = i;
    for (int i = 1; i < n; i++)
        if (last_time[i] == 0)
            insert(org[i].z, i, -1);
    for (int i = 1; i <= q; i++)
    {
        node u = node{qseg[i].x, qseg[i].y};
        // marks;
        for (int j = i; j <= q; j++)
            if (j == i || qseg[i].x != qseg[j].x)
                qframes[j].push_back(u);
            else
                break;
        bool flag = false;
        for (int j = 1; j < i; j++)
            if (qseg[i].x == qseg[j].x)
            {
                flag = true;
                break;
            }
        // first;
        if (!flag)
        {
            node v = node{qseg[i].x, org[qseg[i].x].z};
            for (int j = 0; j < i; j++)
                qframes[j].push_back(v);
        }
    }
    for (int i = 0; i <= q; i++)
        for (int j = 0, siz = qframes[i].size(); j < siz; j++)
            insert(qframes[i][j].time_frame, qframes[i][j].id, i);
    for (int i = 1; i <= n; i++)
        size[i] = 1;
    for (int i = 1; i <= 1e6; i++)
        if (!frames[i].empty())
            solve(i);
    for (int i = 0; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}

 

One Comment

Leave a Reply

Your email address will not be published. Required fields are marked *