「LibreOJ NOI Round #2」不等关系 – 题解

主要思路

这题好神仙啊。

首先介绍一个暴力的套路(虽然跟正解没关系,但我觉得碰上很多比赛的时候会用上):我们设置 \(f[i][j]\) 当前为前 \(i\) 个数中第 \(j\) 小的贡献,那么有

\[ f[i][j] = \begin{cases} \sum_{k = 1}^{i – 1} f[i – 1][k] (s_{i – 1} = <) \\ \sum_{k = i}^{j} f[i – 1][k] (s_{i – 1} = >) \end{cases} \]

这个方法没啥前途,所以我们换一种方法。

一定可以想到这是一个要满足很多条件的计数问题,所以就会不可避免的想到容斥。如果直接对 \(\Theta(n)\) 级别的条件进行容斥是不现实的。所以我们可以考虑转化一类符号,比如固定 \(>\),然后用容斥解决 \(<\)。所有的 \(<\) 在此时都可以被看作为是一种分段,剩下的段都已经确定了从大到小的顺序。如果要计算这样的形状的类型,我们需要做一个多重集排列数,然后根据 \(<\) 进行容斥之后再用 \(n!\) 算回来。

大概是:

\[ f[i] = \sum_{j = 0}^{i -1} [s_j = < \text{or} j = 0] (-1)^{cnt[i – 1] – cnt[j]} \frac{dp[j]}{(i – j)!} \]

分治 NTT 解决一下。

代码

// LOJ575.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 4e5 + 200, mod = 998244353, G = 3;

int n, poly_siz, poly_bit, rev[MAX_N], pre[MAX_N], fac[MAX_N], fac_inv[MAX_N], f[MAX_N], g[MAX_N], A[MAX_N], B[MAX_N];
char opt[MAX_N];

int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

const int Gi = fpow(G, mod - 2);

void ntt_prep(int len)
{
    poly_siz = 1, poly_bit = 0;
    while (poly_siz <= len)
        poly_siz <<= 1, poly_bit++;
}

void ntt_initialize()
{
    for (int i = 0; i < poly_siz; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (poly_bit - 1));
}

void ntt(int *arr, int dft)
{
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    for (int step = 1; step < poly_siz; step <<= 1)
    {
        int omega_n = fpow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
        for (int j = 0; j < poly_siz; j += (step << 1))
        {
            int omega_nk = 1;
            for (int k = j; k < j + step; k++, omega_nk = 1LL * omega_nk * omega_n % mod)
            {
                int t = 1LL * omega_nk * arr[k + step] % mod;
                arr[k + step] = (0LL + arr[k] + mod - t) % mod, arr[k] = (0LL + arr[k] + t) % mod;
            }
        }
    }
    if (dft == -1)
    {
        int inv = fpow(poly_siz, mod - 2);
        for (int i = 0; i < poly_siz; i++)
            arr[i] = 1LL * arr[i] * inv % mod;
    }
}

void solve(int l, int r)
{
    if (l == r)
    {
        if (l == 0)
            f[0] = 1;
        else
            f[l] = (pre[l - 1] & 1) ? (mod - f[l]) % mod : f[l];
        if (opt[l] == '>' || l == 0)
            g[l] = (pre[l] & 1) ? (mod - f[l]) % mod : f[l];
        return;
    }
    int mid = (l + r) >> 1;
    solve(l, mid);
    int lft_len = mid - l + 1, rig_len = r - l;
    ntt_prep(lft_len + rig_len - 1), ntt_initialize();
    for (int i = 0; i < lft_len; i++)
        A[i] = g[i + l];
    for (int i = lft_len; i < poly_siz; i++)
        A[i] = 0;
    for (int i = 0; i < rig_len; i++)
        B[i] = fac_inv[i + 1];
    for (int i = rig_len; i < rig_len; i++)
        B[i] = 0;
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < poly_siz; i++)
        A[i] = 1LL * A[i] * B[i] % mod;
    ntt(A, -1);
    for (int i = mid + 1; i <= r; i++)
        f[i] = (0LL + A[i - l - 1] + f[i]) % mod;
    solve(mid + 1, r);
}

int main()
{
    scanf("%s", opt + 1), n = strlen(opt + 1) + 1;
    for (int i = 1; i <= n; i++)
        pre[i] = pre[i - 1] + (opt[i] == '>');
    for (int i = fac[0] = 1; i <= n; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = fpow(fac[n], mod - 2);
    for (int i = n - 1; i >= 0; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    solve(0, n);
    printf("%lld\n", 1LL * f[n] * fac[n] % mod);
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *