P3233:「HNOI2014」世界树题解

思路

这道题对初学虚树的人来讲简直就是噩梦…之前打了一个 30 分暴力未果,遂转向题解。知道了一个叫做虚树的东西,然后抄了一下午的题解才 AC。

我先来解释一下虚树这个数据结构(?)。在一棵书上,会有关键点集和非关键点集,在一些问题中我们只需要用到关键点之间的关系,而非关键点集便不在那么重要。这个时候我们可以建立一个虚树。

建立虚树的过程相当之繁琐,我在这里不在详细讲,可以使用单调栈和 LCA 算法以极高的效率搞定。详细见:https://oi-wiki.org/ds/virtual-tree/

以下为代码:

// P3233.cpp
#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
#define pr pair<int, int>
using namespace std;
const int MX_N = 300020;
int head[MX_N], current;
struct edge
{
    int to, nxt;
} edges[MX_N << 1];
int fa[MX_N], stfa[20][MX_N], n, dep[MX_N], anses[MX_N], id[MX_N], dfn = 0, q, m;
int tmpx, tmpy, st[MX_N], top = 1, tsiz[MX_N];
pr mx[MX_N];
bool vis[MX_N];
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}
void preprocess()
{
    for (int i = 1; i <= n; i++)
        stfa[0][i] = fa[i];
    for (int tim = 1; tim < 20; tim++)
        for (int u = 1; u <= n; u++)
            stfa[tim][u] = stfa[tim - 1][stfa[tim - 1][u]];
}
int jump(int u, int p)
{
    for (int i = 0; i <= 19; i++)
        if ((p >> i) & 1)
            u = stfa[i][u];
    return u;
}
int getLca(int a, int b)
{
    // b is deeper;
    if (dep[a] > dep[b])
        swap(a, b);
    b = jump(b, dep[b] - dep[a]);
    if (a == b)
        return a;
    for (int tim = 19; tim >= 0; tim--)
        if (stfa[tim][a] != stfa[tim][b])
            a = stfa[tim][a], b = stfa[tim][b];
    return fa[a];
}
void dfs_fa(int u)
{
    id[u] = ++dfn;
    tsiz[u] = 1;
    dep[u] = dep[fa[u]] + 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (fa[u] != edges[i].to)
            fa[edges[i].to] = u, dfs_fa(edges[i].to), tsiz[u] += tsiz[edges[i].to];
}
bool compare(const int &a, const int &b) { return id[a] < id[b]; }
void dfs_1(int u)
{
    if (vis[u])
        mx[u] = make_pair(0, u);
    else
        mx[u] = make_pair(1e8, 0);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        int to = edges[i].to;
        dfs_1(to);
        pr tmp = mx[to];
        tmp.first = dep[mx[to].second] - dep[u];
        mx[u] = min(mx[u], tmp);
    }
}
void dfs_2(int u)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        pr p = mx[u];
        p.first += dep[edges[i].to] - dep[u];
        mx[edges[i].to] = min(mx[edges[i].to], p);
        dfs_2(edges[i].to);
    }
    anses[mx[u].second] = max(anses[mx[u].second], tsiz[u]);
}
void dfs_3(int u)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        int x = mx[u].second, y = mx[edges[i].to].second;
        if (x != y)
        {
            int dist = dep[x] + dep[y] - (dep[getLca(x, y)] << 1);
            int z = jump(edges[i].to, (dist >> 1) - mx[edges[i].to].first);
            if (dist & 1)
                anses[x] -= tsiz[z];
            else
            {
                if (z != u && z != edges[i].to)
                    z = jump(edges[i].to, (dist >> 1) - mx[edges[i].to].first - (x < y));
                else if (z == u)
                    z = jump(edges[i].to, (dist >> 1) - mx[edges[i].to].first - 1);
                anses[x] -= tsiz[z];
            }
            if (edges[i].to != z)
                anses[y] += tsiz[z] - tsiz[edges[i].to];
        }
        dfs_3(edges[i].to);
    }
}
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
        scanf("%d%d", &tmpx, &tmpy), addpath(tmpx, tmpy), addpath(tmpy, tmpx);
    dfs_fa(1);
    preprocess();
    scanf("%d", &q);
    while (q--)
    {
        current = 0;
        scanf("%d", &m);
        vector<int> harr, arrs;
        for (int i = 1; i <= m; i++)
            scanf("%d", &tmpx), vis[tmpx] = true, harr.push_back(tmpx), anses[tmpx] = 0, arrs.push_back(tmpx);
        sort(harr.begin(), harr.end(), compare);
        // start to build the virtual tree;
        // prep for the stack;
        st[top = 1] = 1, head[1] = -1;
        for (int i = 0; i < m; i++)
        {
            if (harr[i] == 1)
                continue;
            int curtpt = harr[i], lca = getLca(curtpt, st[top]);
            if (lca != st[top])
            {
                while (id[lca] < id[st[top - 1]])
                    addpath(st[top - 1], st[top]), top--;
                if (id[lca] > id[st[top - 1]])
                    head[lca] = -1, addpath(lca, st[top]), st[top] = lca;
                else
                    addpath(lca, st[top--]);
            }
            head[curtpt] = -1, st[++top] = curtpt;
        }
        for (int i = 1; i < top; i++)
            addpath(st[i], st[i + 1]);
        dfs_1(1), dfs_2(1), dfs_3(1);
        for (int i = 0; i < m; i++)
            printf("%d ", anses[arrs[i]]);
        printf("\n");
        for (int i = 0; i < m; i++)
            vis[arrs[i]] = false;
    }
    return 0;
}

 

POJ1417:True Liars 题解

这道题好毒瘤啊…本来想着可以直接写个并查集 A 掉没想到还需要背包 DP。我们一段一段来讲。

// POJ1417.cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int maxn = 400, maxm = 2010;
int n, p1, p2, fa[maxm], tot, cnt[maxm], cur;
int dp[1205][maxn];
vector<int> T;
bool pre[1205][maxn];
struct team
{
    int sam, diff;
} nodes[maxm];
bool init()
{
    scanf("%d%d%d", &n, &p1, &p2);
    if (n == 0 && p1 == 0 && p2 == 0)
        return false;
    tot = p1 + p2;
    for (int i = 0; i < maxm; i++)
        fa[i] = i;
    return true;
}
int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]); }

首先是声明数组和变量们,并且做初始化。之后我们写一个 solve 函数:

while (n--)
{
    int u, v;
    char opt[10];
    scanf("%d%d%s", &u, &v, opt);
    if (opt[0] == 'n')
        fa[find(u)] = find(v + tot), fa[find(u + tot)] = find(v);
    else
        fa[find(u)] = find(v), fa[find(u + tot)] = find(v + tot);
}

我们可以推理得出,如果操作为\(yes\),那么他们就是同类,亦而反之。在这里,就可能会形成一个并查集森林:有多余\(2\)个的类型,所以我们需要用背包 DP 来计算能不能凑出唯一的天使和恶魔配比。在此之前,我们先要找出这些森林中的树:

cur = 0;
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= tot; i++)
{
    int root = find(i);
    if (cnt[root] == 0 && root <= tot)
        nodes[++cur] = (team){root, find(i + tot)};
    cnt[root]++;
}

找到没有被访问过的树,顺便统计子树大小。之后我们进行背包 DP。

memset(dp, 0, sizeof(dp));
dp[0][0] = 1;
for (int i = 1; i <= cur; i++)
    for (int j = 0; j <= p1; j++)
        if (dp[i - 1][j])
        {
            if (j + cnt[nodes[i].sam] <= p1)
            {
                dp[i][j + cnt[nodes[i].sam]] += dp[i - 1][j];
                pre[i][j + cnt[nodes[i].sam]] = true;
            }
            if (j + cnt[nodes[i].diff] <= p1)
            {
                dp[i][j + cnt[nodes[i].diff]] += dp[i - 1][j];
                pre[i][j + cnt[nodes[i].diff]] = false;
            }
        }

叠加可能的次数,最终如果答案为\(1\),意味着有唯一解。我们需要筛除多解和无解的情况,然后统计答案。

if (dp[cur][p1] != 1)
{
    puts("no");
    return;
}
int C = p1;
for (int i = cur; i >= 1; i--)
    if (pre[i][C])
        C -= cnt[nodes[i].sam], T.push_back(nodes[i].sam);
    else
        C -= cnt[nodes[i].diff], T.push_back(nodes[i].diff);
for (int i = 1; i <= tot; i++)
{
    int rt = find(i);
    if (find(T.begin(), T.end(), rt) != T.end())
        printf("%d\n", i);
}
T.clear();
printf("end\n");

最后完整代码附上:

// POJ1417.cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int maxn = 400, maxm = 2010;
int n, p1, p2, fa[maxm], tot, cnt[maxm], cur;
int dp[1205][maxn];
vector<int> T;
bool pre[1205][maxn];
struct team
{
    int sam, diff;
} nodes[maxm];
bool init()
{
    scanf("%d%d%d", &n, &p1, &p2);
    if (n == 0 && p1 == 0 && p2 == 0)
        return false;
    tot = p1 + p2;
    for (int i = 0; i < maxm; i++)
        fa[i] = i;
    return true;
}
int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]); }
// 1 divine, 2 devil
void solve()
{
    while (n--)
    {
        int u, v;
        char opt[10];
        scanf("%d%d%s", &u, &v, opt);
        if (opt[0] == 'n')
            fa[find(u)] = find(v + tot), fa[find(u + tot)] = find(v);
        else
            fa[find(u)] = find(v), fa[find(u + tot)] = find(v + tot);
    }
    cur = 0;
    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= tot; i++)
    {
        int root = find(i);
        if (cnt[root] == 0 && root <= tot)
            nodes[++cur] = (team){root, find(i + tot)};
        cnt[root]++;
    }
    memset(dp, 0, sizeof(dp));
    //memset(pre, 0, sizeof(pre));
    dp[0][0] = 1;
    for (int i = 1; i <= cur; i++)
        for (int j = 0; j <= p1; j++)
            if (dp[i - 1][j])
            {
                if (j + cnt[nodes[i].sam] <= p1)
                {
                    dp[i][j + cnt[nodes[i].sam]] += dp[i - 1][j];
                    pre[i][j + cnt[nodes[i].sam]] = true;
                }
                if (j + cnt[nodes[i].diff] <= p1)
                {
                    dp[i][j + cnt[nodes[i].diff]] += dp[i - 1][j];
                    pre[i][j + cnt[nodes[i].diff]] = false;
                }
            }
    if (dp[cur][p1] != 1)
    {
        puts("no");
        return;
    }
    int C = p1;
    for (int i = cur; i >= 1; i--)
        if (pre[i][C])
            C -= cnt[nodes[i].sam], T.push_back(nodes[i].sam);
        else
            C -= cnt[nodes[i].diff], T.push_back(nodes[i].diff);
    for (int i = 1; i <= tot; i++)
    {
        int rt = find(i);
        if (find(T.begin(), T.end(), rt) != T.end())
            printf("%d\n", i);
    }
    T.clear();
    printf("end\n");
}
int main()
{
    while (init())
        solve();
    return 0;
}

 

P2110:欢总喊楼记题解

神仙做法

我看题解之后非常心塞,竟然不用数位 DP !我们先来考虑部分解。当\(num<10\)时,答案就是\(num\)。如果\(num \geq 10\),那么答案就是\(9+\frac{x}{10}\)。我们来看,考虑一个数\(1921\)的答案,我们可以考虑\(192-\)部分里有\(\lfloor \frac{1920}{10} \rfloor\)对首尾相同的数。然后再加上位数为一时的答案。

Continue reading →