主要思路
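先建好树,稍加思考可以推导出:修改一个叶子后,变化只会沿着它到根的路径向上传递一段,因此可以用倍增(或在平衡树上查找)找到这段变化停止的点,然后对“该点到叶子父亲”这一段链做区间修改。正解也是这个思路(下面的代码用 LCT 维护),只不过细节有点多,我在这里提一下:回答询问 \(x\) 时,先对它的父亲执行 access 并把它 splay 到根,然后在这条链上找深度最大的、“值为 1 的儿子个数”不等于 1(叶子由 0 变 1 时)或不等于 2(叶子由 1 变 0 时)的点;找到了就把比它更深的点全部打上标记并整体修改,再单独修改这个点本身;找不到则整条链都被修改,根的输出随之翻转。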
代码
// LOJ2187.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node array: (5e5 + 200) * 3 entries. main() stores values
// for node ids up to 3 * n + 1, so this bounds n at roughly 5e5.
const int MAX_N = (5e5 + 200) * 3;
// Per-node state, indexed by node id (id 0 is used as "no node"):
//   ch[p][0], ch[p][1] - left / right child of p inside its splay tree
//   fa[p]              - parent pointer; initialised in main() to the parent read from
//                        the input and later rewired by rotate()
//   siz[p]             - splay subtree size maintained by pushup()
//   idx[p][k], k=1,2   - a node of p's splay subtree whose sum differs from k, chosen by
//                        pushup() with priority right subtree, then p, then left subtree;
//                        0 when no such node exists (idx[p][0] is unused)
//   tag[p]             - delta already applied to p but still pending for p's splay children
//   sum[p]             - number of children of p whose val is 1 (built by dfs())
//   val[p]             - value of the node: the input value for ids > n, sum[p] > 1 otherwise
int ch[MAX_N][2], fa[MAX_N], siz[MAX_N], idx[MAX_N][3], tag[MAX_N], sum[MAX_N], val[MAX_N];
// head[u]: index of the first edge in u's adjacency list, -1 when the list is empty.
// current: next free slot in the edge pool. n, q: counts read by main().
int head[MAX_N], current, n, q;
// One entry of a singly linked adjacency list stored in a flat pool.
struct edge
{
// to: destination node of the edge; nxt: index of the next edge leaving the same
// source node, or -1 at the end of the list.
int to, nxt;
} edges[MAX_N << 1]; // edge pool; slots are handed out sequentially by addpath()
// Inserts the directed edge src -> dst at the front of src's adjacency list,
// taking the next unused slot of the edge pool.
void addpath(int src, int dst)
{
    const int slot = current++;
    edges[slot].nxt = head[src];
    edges[slot].to = dst;
    head[src] = slot;
}
// Shorthand for the left / right splay child of the node held in a variable named `p`.
// Only meaningful inside functions whose node variable is literally called `p`.
#define lson (ch[p][0])
#define rson (ch[p][1])
// Link-cut tree primitives (splay-based) follow.
// Returns 1 when p is stored as the right child of fa[p], 0 otherwise.
int check(int p)
{
    const int parent = fa[p];
    return ch[parent][1] == p;
}
// Returns true when fa[p] does not list p as either of its splay children,
// i.e. p is the root of its own splay tree.
bool isRoot(int p)
{
    const int parent = fa[p];
    if (ch[parent][1] == p)
        return false;
    if (ch[parent][0] == p)
        return false;
    return true;
}
void pushup(int p)
{
siz[p] = siz[lson] + siz[rson] + 1;
idx[p][1] = idx[rson][1], idx[p][2] = idx[rson][2];
if (!idx[p][1])
{
if (sum[p] != 1)
idx[p][1] = p;
else
idx[p][1] = idx[lson][1];
}
if (!idx[p][2])
{
if (sum[p] != 2)
idx[p][2] = p;
else
idx[p][2] = idx[lson][2];
}
}
// Applies a delta to node p on behalf of its whole splay subtree: adjusts sum[p],
// refreshes val[p], swaps the two idx slots, and accumulates the delta in tag[p]
// so pushdown() can forward it to the children later.
void sumup(int p, int value)
{
    tag[p] += value;
    sum[p] += value;
    val[p] = (sum[p] > 1) ? 1 : 0;
    std::swap(idx[p][1], idx[p][2]);
}
void pushdown(int p)
{
if (tag[p] != 0)
{
if (lson)
sumup(lson, tag[p]);
if (rson)
sumup(rson, tag[p]);
tag[p] = 0;
}
}
// Rotates x one level up over its splay parent y = fa[x].
// The caller must ensure x is not the root of its splay tree; splay() also runs
// update() first, so no tag is pending on the nodes touched here.
void rotate(int x)
{
// y: parent of x, z: parent of y, dir: 1 if x is y's right child (else 0),
// w: the child of x on the opposite side of dir, which is handed over to y.
int y = fa[x], z = fa[y], dir = check(x), w = ch[x][dir ^ 1];
fa[x] = z;
// Replace y by x under z only when y really is a splay child of z. This must happen
// before fa[y] is overwritten below, because isRoot(y) and check(y) read fa[y].
if (!isRoot(y))
ch[z][check(y)] = x;
fa[y] = x, ch[x][dir ^ 1] = y;
// w may be 0 (no subtree); the write then lands in fa[0].
fa[w] = y, ch[y][dir] = w;
// y now sits below x, so its aggregates are rebuilt first.
pushup(y), pushup(x);
}
// Pushes pending tags down along the splay path that leads to p, starting at the
// root of p's splay tree and finishing at p itself.
//
// Implemented with an explicit stack instead of recursion: access() can leave a
// splay tree as deep as an entire root path of the input tree, and recursing once
// per level could exhaust the call stack on large inputs.
void update(int p)
{
    // Reused between calls so a splay does not allocate every time.
    static std::vector<int> chain;
    chain.clear();

    // Collect p and all of its splay ancestors (bottom-up).
    chain.push_back(p);
    while (!isRoot(chain.back()))
        chain.push_back(fa[chain.back()]);

    // Apply pushdown top-down: the splay root was pushed last, p first.
    while (!chain.empty())
    {
        pushdown(chain.back());
        chain.pop_back();
    }
}
void splay(int p)
{
update(p);
for (int fat = fa[p]; fat = fa[p], !isRoot(p); rotate(p))
if (!isRoot(fat))
rotate(fat == p ? fat : p);
pushup(p);
}
void access(int p)
{
for (int pre = 0; p != 0; pre = p, p = fa[p])
splay(p), rson = pre, pushup(p);
}
#undef rson
#undef lson
// Computes the initial state of the subtree rooted at u: for every node, sum[] is
// increased by the number of its children whose val is 1, and for nodes with
// id <= n val[] is set to 1 when that count exceeds 1, otherwise 0. Nodes with
// id > n keep the val they already have.
//
// The traversal is iterative: nodes are listed in breadth-first order (every
// child after its parent) and then processed back to front, so each child is
// final before its parent reads it. A recursive walk would nest as deep as the
// tree is high, which can reach n levels and overflow the call stack.
void dfs(int u)
{
    std::vector<int> order;
    order.push_back(u);
    for (std::size_t at = 0; at < order.size(); at++)
        for (int e = head[order[at]]; e != -1; e = edges[e].nxt)
            order.push_back(edges[e].to);

    for (std::size_t at = order.size(); at-- > 0;)
    {
        const int node = order[at];
        for (int e = head[node]; e != -1; e = edges[e].nxt)
            sum[node] += val[edges[e].to];
        if (node <= n)
            val[node] = sum[node] > 1;
    }
}
// Reads the tree and the queries from stdin and prints the value of node 1 after
// every query.
//
// Input layout as consumed here: n, then three child ids for each node 1..n, then
// one value for each node n+1..3n+1, then q, then q node ids. Each query toggles
// the value of the given node.
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int node = 1; node <= n; node++)
    {
        int kids[3];
        scanf("%d%d%d", &kids[0], &kids[1], &kids[2]);
        for (int k = 0; k < 3; k++)
            addpath(node, kids[k]);
        for (int k = 0; k < 3; k++)
            fa[kids[k]] = node;
    }
    for (int leaf = n + 1; leaf <= 3 * n + 1; leaf++)
        scanf("%d", &val[leaf]);

    dfs(1);
    scanf("%d", &q);

    int finalStat = val[1];
    while (q--)
    {
        int x;
        scanf("%d", &x);
        const int fat = fa[x];
        const bool wasOne = val[x] != 0;
        // Turning a 1 into a 0 lowers the parent's count, the opposite raises it.
        const int targetDelta = wasOne ? -1 : 1;

        access(fat);
        splay(fat);

        // Node on the accessed path where the change stops propagating upwards:
        // one whose sum is not 2 when decreasing, not 1 when increasing.
        const int mid = idx[fat][wasOne ? 2 : 1];
        if (mid == 0)
        {
            // No such node: the whole path changes, so the value of node 1 flips.
            finalStat ^= 1;
            sumup(fat, targetDelta);
            pushup(fat);
        }
        else
        {
            splay(mid);
            // Everything in mid's right subtree receives the delta as a lazy tag.
            sumup(ch[mid][1], targetDelta);
            pushup(ch[mid][1]);
            // mid itself only gets its counter and value adjusted.
            sum[mid] += targetDelta;
            val[mid] = sum[mid] > 1;
            pushup(mid);
        }

        val[x] ^= 1;
        printf("%d\n", finalStat);
    }
    return 0;
}