P4428:「BJOI2018」二进制 – 题解

主要思路

打个表发现合法的区间需要满足以下条件之一:

  • 如果区间内的 \(1\) 的个数为奇数,则至少要有 \(2\) 个 \(0\)。
  • 区间内的 \(1\) 的个数为偶数。

我们可以考虑维护不合法的区间,因为不合法的区间特征明显一点,只有这种:

  • 如果区间内的 \(1\) 的个数为奇数,且 \(0\) 的个数 \(< 2\)。

我们可以考虑用线段树来维护这个信息。大概就是需要固定左右端点的选择,然后再用分治的方法处理跨两个子树的区间即可。

// P4428.cpp
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAX_N = 1e5 + 200;

int n, q, bits, str[MAX_N];

struct node
{
    ll clft[2][2], crig[2][2], fl[3], fr[3], ans, lftzero, rigzero;
    int sum[2];

    void clear()
    {
        memset(clft, 0, sizeof(clft)), memset(crig, 0, sizeof(crig));
        memset(fl, 0, sizeof(fl)), memset(fr, 0, sizeof(fr));
        lftzero = rigzero = false;
        sum[0] = sum[1] = ans = 0;
    }

    void load(int x)
    {
        clear();
        if (x)
            clft[0][1] = crig[0][1] = sum[1] = ans = fl[0] = fr[0] = 1;
        else
            clft[1][0] = crig[1][0] = sum[0] = lftzero = rigzero = 1;
    }
} nodes[MAX_N << 2];

node pushup(node ls, node rs)
{
    node ret;
    ret.clear();
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
        {
            ret.clft[i][j] += ls.clft[i][j];
            ret.crig[i][j] += rs.crig[i][j];
            if (i >= ls.sum[0])
                ret.clft[i][j] += rs.clft[i - ls.sum[0]][j ^ (ls.sum[1] & 1)];
            if (i >= rs.sum[0])
                ret.crig[i][j] += ls.crig[i - rs.sum[0]][j ^ (rs.sum[1] & 1)];
        }
    for (int i = 0; i < 3; i++)
    {
        ret.fl[i] += ls.fl[i], ret.fr[i] += rs.fr[i];
        if (ls.sum[1] == 0)
            ret.fl[min(2, i + ls.sum[0])] += rs.fl[i];
        if (rs.sum[1] == 0)
            ret.fr[min(2, i + rs.sum[0])] += ls.fr[i];
    }
    if (ls.sum[1] == 1 && rs.lftzero)
        ret.fl[min(2LL, ls.sum[0] + rs.lftzero)]++, ret.fl[2] += rs.lftzero - 1;
    if (rs.sum[1] == 1 && ls.rigzero)
        ret.fr[min(2LL, rs.sum[0] + ls.rigzero)]++, ret.fr[2] += ls.rigzero - 1;
    ret.lftzero = ls.lftzero + ((ls.sum[1] == 0) ? rs.lftzero : 0);
    ret.rigzero = rs.rigzero + ((rs.sum[1] == 0) ? ls.rigzero : 0);
    ret.sum[0] = ls.sum[0] + rs.sum[0], ret.sum[1] = ls.sum[1] + rs.sum[1];
    ret.ans = ls.ans + rs.ans;
    ret.ans += ls.crig[0][0] * (rs.clft[0][1] + rs.clft[1][1]);
    ret.ans += ls.crig[0][1] * (rs.clft[0][0] + rs.clft[1][0]);
    ret.ans += ls.crig[1][0] * rs.clft[0][1];
    ret.ans += ls.crig[1][1] * rs.clft[0][0];
    if (rs.lftzero)
        ret.ans += rs.lftzero * (ls.fr[0] + ls.fr[1] + ls.fr[2]) - ls.fr[0];
    if (ls.rigzero)
        ret.ans += ls.rigzero * (rs.fl[0] + rs.fl[1] + rs.fl[2]) - rs.fl[0];
    return ret;
}

#define lson (p << 1)
#define rson ((p << 1) | 1)
#define mid ((l + r) >> 1)

void build(int l, int r, int p)
{
    if (l == r)
        return (void)(nodes[p].load(str[l]));
    build(l, mid, lson), build(mid + 1, r, rson);
    nodes[p] = pushup(nodes[lson], nodes[rson]);
}

void update(int qx, int l, int r, int p)
{
    if (l == r)
        return (void)(nodes[p].load(str[l]));
    if (qx <= mid)
        update(qx, l, mid, lson);
    else
        update(qx, mid + 1, r, rson);
    nodes[p] = pushup(nodes[lson], nodes[rson]);
}

node query(int ql, int qr, int l, int r, int p)
{
    if (ql <= l && r <= qr)
        return nodes[p];
    bool lft = (ql <= mid), rig = (mid < qr);
    if (lft && rig)
        return pushup(query(ql, qr, l, mid, lson), query(ql, qr, mid + 1, r, rson));
    if (lft)
        return query(ql, qr, l, mid, lson);
    else
        return query(ql, qr, mid + 1, r, rson);
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &str[i]);
    build(1, n, 1), scanf("%d", &q);
    while (q--)
    {
        int opt, l, r;
        scanf("%d%d", &opt, &l);
        if (opt == 1)
            str[l] ^= 1, update(l, 1, n, 1);
        else
            scanf("%d", &r), printf("%lld\n", 1LL * (r - l + 1) * (r - l + 2) / 2 - query(l, r, 1, n, 1).ans);
    }
    return 0;
}


Leave a Reply

Your email address will not be published. Required fields are marked *