Loading [MathJax]/extensions/tex2jax.js

HNOI 2018 省队集训 Day 1 – 解题报告

A – Tree

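这道题粗看需要 Link Cut Tree,其实不然:我们仍然以节点 \(1\) 为根预处理 DFS 序和倍增祖先,把子树信息放入线段树,之后的询问只需要灵活地分类讨论即可。在实际根为 \(root\) 时,\(u, v\) 的 \(LCA\) 是 \(lca(u, v), lca(root, u), lca(root, v)\) 中深度最大的那一个。修改或查询某个点 \(w\) 的子树时,只需讨论 \(w\) 与 \(root\) 的关系:若 \(w = root\),作用于整棵树;若 \(root\) 在(以 \(1\) 为根时)\(w\) 的子树中,设 \(c\) 为 \(w\) 通往 \(root\) 方向的那个儿子,则作用于整棵树减去 \(c\) 的子树;否则二者为平行关系,直接作用于 \(w\) 的子树。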

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// QOJ2030.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Upper bound on the number of vertices (3e5 plus slack).
const int MAX_N = 3e5 + 200;

// n: vertex count; q: query count.
int n, q;
// seq: initial vertex weights; [lft[u], rig[u]]: DFS-order interval of u's subtree (tree rooted at 1);
// anti: inverse of lft (DFS position -> vertex).
int seq[MAX_N], lft[MAX_N], rig[MAX_N], anti[MAX_N];
// root: the current (virtual) root; fa[i][u]: 2^i-th ancestor of u in the tree rooted at 1.
int root, fa[20][MAX_N];
// head: first edge index of each adjacency list; current: next free slot in edges[].
int head[MAX_N], current;
// dep: depth (vertex 1 has depth 1, the sentinel 0 keeps depth 0); ptot: DFS timestamp counter.
int dep[MAX_N], ptot;
// One directed half of an undirected tree edge in a linked adjacency list:
// `to` is the target vertex, `nxt` the index of the next edge leaving the same source (-1 ends the list).
struct edge
{
    int to;
    int nxt;
} edges[MAX_N << 1];
void addpath(int src, int dst)
{
edges[current].to = dst, edges[current].nxt = head[src];
head[src] = current++;
}
namespace SegmentTree
{
ll nodes[MAX_N << 2], tag[MAX_N << 2];
#define lson (p << 1)
#define rson ((p << 1) | 1)
#define mid ((l + r) >> 1)
void build(int l, int r, int p)
{
if (l == r)
return (void)(nodes[p] = seq[anti[l]]);
build(l, mid, lson), build(mid + 1, r, rson);
nodes[p] = nodes[lson] + nodes[rson];
}
void pushdown(int p, int l, int r)
{
if (tag[p] != 0)
{
tag[lson] += tag[p], tag[rson] += tag[p];
nodes[lson] += 1LL * tag[p] * (mid - l + 1), nodes[rson] += 1LL * tag[p] * (r - mid);
tag[p] = 0;
}
}
void update(int ql, int qr, int l, int r, int p, ll val)
{
if (ql <= l && r <= qr)
{
nodes[p] += 1LL * (r - l + 1) * val, tag[p] += val;
return;
}
pushdown(p, l, r);
if (ql <= mid)
update(ql, qr, l, mid, lson, val);
if (mid < qr)
update(ql, qr, mid + 1, r, rson, val);
nodes[p] = nodes[lson] + nodes[rson];
}
ll query(int ql, int qr, int l, int r, int p)
{
if (ql <= l && r <= qr)
return nodes[p];
pushdown(p, l, r);
ll ret = 0;
if (ql <= mid)
ret += query(ql, qr, l, mid, lson);
if (mid < qr)
ret += query(ql, qr, mid + 1, r, rson);
return ret;
}
#undef mid
#undef rson
#undef lson
} // namespace SegmentTree
// Root the tree at u (with parent fat) and record, in DFS preorder, for every vertex v reached:
// depth dep[v], parent fa[0][v], entry time lft[v] (with anti[lft[v]] = v) and the last entry
// time inside its subtree rig[v], so subtree(v) occupies DFS positions [lft[v], rig[v]].
// Uses an explicit stack instead of recursion so a path-shaped tree with ~MAX_N vertices cannot
// overflow the call stack; vertices are visited in exactly the same order as the recursive DFS.
void dfs_init(int u, int fat)
{
    // Each frame: (vertex, index of the next adjacency edge still to scan; -1 when exhausted).
    std::vector<std::pair<int, int>> frames;
    auto enter = [&](int v, int parent) {
        dep[v] = dep[parent] + 1, fa[0][v] = parent, lft[v] = ++ptot, anti[ptot] = v;
        frames.emplace_back(v, head[v]);
    };
    enter(u, fat);
    while (!frames.empty())
    {
        const int v = frames.back().first;
        const int e = frames.back().second;
        if (e == -1)
        {
            // Every child is finished: close v's DFS interval.
            rig[v] = ptot;
            frames.pop_back();
            continue;
        }
        // Advance this frame's cursor before entering a child (emplace_back may reallocate).
        frames.back().second = edges[e].nxt;
        if (edges[e].to != fa[0][v])
            enter(edges[e].to, v);
    }
}
// Lowest common ancestor of x and y in the tree rooted at vertex 1, via binary lifting.
int getLCA(int x, int y)
{
    // Make x the deeper vertex, then lift it to y's depth.
    if (dep[x] < dep[y])
        swap(x, y);
    for (int k = 19; k >= 0; --k)
    {
        const int up = fa[k][x];
        if (dep[up] >= dep[y])
            x = up;
    }
    if (x == y)
        return x;
    // Lift both while their ancestors differ; they end as the two children of the LCA.
    for (int k = 19; k >= 0; --k)
    {
        if (fa[k][x] == fa[k][y])
            continue;
        x = fa[k][x];
        y = fa[k][y];
    }
    return fa[0][x];
}
// True when v lies in the subtree of anc in the tree rooted at vertex 1
// (DFS intervals are either nested or disjoint, so checking lft[v] suffices).
static bool inSubtree(int anc, int v)
{
    return lft[anc] <= lft[v] && lft[v] <= rig[anc];
}

// The child of anc (tree rooted at 1) whose subtree contains v.
// Requires v to be a proper descendant of anc.
static int childToward(int anc, int v)
{
    for (int i = 19; i >= 0; i--)
        if (dep[fa[i][v]] > dep[anc])
            v = fa[i][v];
    return v;
}

// Reads the tree and processes the queries:
//   1 x      -> make x the current root;
//   2 x y z  -> add z to every vertex in the subtree (w.r.t. the current root) of LCA(x, y);
//   3 x      -> print the weight sum of x's subtree w.r.t. the current root.
// Everything is built once with vertex 1 as the root; re-rooting is handled by case analysis.
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++)
        scanf("%d", &seq[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    dfs_init(1, 0), root = 1;
    SegmentTree::build(1, n, 1);
    for (int i = 1; i < 20; i++)
        for (int j = 1; j <= n; j++)
            fa[i][j] = fa[i - 1][fa[i - 1][j]];
    while (q--)
    {
        int opt, x, y, z;
        scanf("%d%d", &opt, &x);
        if (opt == 1)
            root = x;
        else if (opt == 2)
        {
            scanf("%d%d", &y, &z);
            // Under the current root, LCA(x, y) is the deepest of these three 1-rooted LCAs.
            int lca = getLCA(x, y), lca1 = getLCA(root, x), lca2 = getLCA(root, y);
            int dlca = (dep[lca1] > dep[lca2] ? lca1 : lca2);
            if (dep[lca] > dep[dlca])
                dlca = lca;
            if (dlca == root)
                SegmentTree::update(1, n, 1, n, 1, z);
            else if (inSubtree(dlca, root))
            {
                // Whole tree except the branch of dlca that leads towards root.
                int u = childToward(dlca, root);
                SegmentTree::update(1, n, 1, n, 1, z);
                SegmentTree::update(lft[u], rig[u], 1, n, 1, -z);
            }
            else
                SegmentTree::update(lft[dlca], rig[dlca], 1, n, 1, z);
        }
        else if (opt == 3)
        {
            if (x == root)
                printf("%lld\n", SegmentTree::nodes[1]);
            else if (inSubtree(x, root))
            {
                // Whole tree except the branch of x that leads towards root.
                int u = childToward(x, root);
                printf("%lld\n", SegmentTree::nodes[1] - SegmentTree::query(lft[u], rig[u], 1, n, 1));
            }
            else
                printf("%lld\n", SegmentTree::query(lft[x], rig[x], 1, n, 1));
        }
    }
    return 0;
}
// QOJ2030.cpp
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
#define ll long long

using namespace std;

const int MAX_N = 3e5 + 200;

// seq: initial weights; [lft[u], rig[u]]: DFS interval of u (rooted at 1); anti: DFS position -> vertex;
// root: current root; fa[i][u]: 2^i-th ancestor; head/current: adjacency lists.
int n, q, seq[MAX_N], lft[MAX_N], rig[MAX_N], anti[MAX_N], root, fa[20][MAX_N], head[MAX_N], current;
int dep[MAX_N], ptot;

// Linked adjacency list: `to` is the target, `nxt` the next edge from the same source (-1 ends).
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

// Prepend the directed edge src -> dst to src's adjacency list.
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

// Range-add / range-sum lazy segment tree over DFS positions 1..n.
namespace SegmentTree
{

ll nodes[MAX_N << 2], tag[MAX_N << 2];

// Build node p over [l, r] from the initial weights in DFS order.
void build(int l, int r, int p)
{
    if (l == r)
    {
        nodes[p] = seq[anti[l]];
        return;
    }
    const int m = (l + r) >> 1;
    build(l, m, p << 1);
    build(m + 1, r, (p << 1) | 1);
    nodes[p] = nodes[p << 1] + nodes[(p << 1) | 1];
}

// Push node p's pending addition (p covers [l, r]) into both children.
void pushdown(int p, int l, int r)
{
    if (tag[p] == 0)
        return;
    const int m = (l + r) >> 1;
    const int lc = p << 1, rc = (p << 1) | 1;
    tag[lc] += tag[p];
    tag[rc] += tag[p];
    nodes[lc] += 1LL * tag[p] * (m - l + 1);
    nodes[rc] += 1LL * tag[p] * (r - m);
    tag[p] = 0;
}

// Add val to every position of [ql, qr]; node p covers [l, r].
void update(int ql, int qr, int l, int r, int p, ll val)
{
    if (ql <= l && r <= qr)
    {
        nodes[p] += 1LL * (r - l + 1) * val;
        tag[p] += val;
        return;
    }
    pushdown(p, l, r);
    const int m = (l + r) >> 1;
    if (ql <= m)
        update(ql, qr, l, m, p << 1, val);
    if (qr > m)
        update(ql, qr, m + 1, r, (p << 1) | 1, val);
    nodes[p] = nodes[p << 1] + nodes[(p << 1) | 1];
}

// Sum over positions [ql, qr]; node p covers [l, r].
ll query(int ql, int qr, int l, int r, int p)
{
    if (ql <= l && r <= qr)
        return nodes[p];
    pushdown(p, l, r);
    const int m = (l + r) >> 1;
    ll sum = 0;
    if (ql <= m)
        sum += query(ql, qr, l, m, p << 1);
    if (qr > m)
        sum += query(ql, qr, m + 1, r, (p << 1) | 1);
    return sum;
}

} // namespace SegmentTree

// Iterative DFS from u (parent fat): fills dep, fa[0], lft, anti, rig in preorder.
// An explicit stack avoids call-stack overflow on deep (path-shaped) trees.
void dfs_init(int u, int fat)
{
    std::vector<std::pair<int, int>> frames; // (vertex, next edge to scan)
    auto enter = [&](int v, int parent) {
        dep[v] = dep[parent] + 1, fa[0][v] = parent, lft[v] = ++ptot, anti[ptot] = v;
        frames.emplace_back(v, head[v]);
    };
    enter(u, fat);
    while (!frames.empty())
    {
        const int v = frames.back().first;
        const int e = frames.back().second;
        if (e == -1)
        {
            rig[v] = ptot;
            frames.pop_back();
            continue;
        }
        // Advance the cursor before a push that may reallocate `frames`.
        frames.back().second = edges[e].nxt;
        if (edges[e].to != fa[0][v])
            enter(edges[e].to, v);
    }
}

// LCA of x and y in the tree rooted at 1 (binary lifting over fa[][]).
int getLCA(int x, int y)
{
    if (dep[x] < dep[y])
        swap(x, y);
    for (int i = 19; i >= 0; i--)
        if (dep[fa[i][x]] >= dep[y])
            x = fa[i][x];
    if (x == y)
        return x;
    for (int i = 19; i >= 0; i--)
        if (fa[i][x] != fa[i][y])
            x = fa[i][x], y = fa[i][y];
    return fa[0][x];
}
int main() { memset(head, -1, sizeof(head)), root = 1; scanf("%d%d", &n, &q); for (int i = 1; i <= n; i++) scanf("%d", &seq[i]); for (int i = 1, u, v; i <= n - 1; i++) scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u); dfs_init(1, 0), root = 1, SegmentTree::build(1, n, 1); for (int i = 1; i < 20; i++) for (int j = 1; j <= n; j++) fa[i][j] = fa[i - 1][fa[i - 1][j]]; while (q--) { int opt, x, y, z; scanf("%d%d", &opt, &x); if (opt == 1) root = x; else if (opt == 2) { scanf("%d%d", &y, &z); int lca = getLCA(x, y), lca1 = getLCA(root, x), lca2 = getLCA(root, y); int dlca = (dep[lca1] > dep[lca2] ? lca1 : lca2); if (dep[lca] > dep[dlca]) dlca = lca; if (dlca == root) SegmentTree::update(1, n, 1, n, 1, z); else if (lft[dlca] <= lft[root] && rig[root] <= rig[dlca]) { int u = root; for (int i = 19; i >= 0; i--) if (dep[fa[i][u]] > dep[dlca]) u = fa[i][u]; SegmentTree::update(1, n, 1, n, 1, z); SegmentTree::update(lft[u], rig[u], 1, n, 1, -z); } else SegmentTree::update(lft[lca], rig[lca], 1, n, 1, z); } else if (opt == 3) { // something different; // if there is an ancestor relationship; if (x == root) printf("%lld\n", SegmentTree::nodes[1]); else if (lft[x] <= lft[root] && lft[root] <= rig[x]) { int u = root; for (int i = 19; i >= 0; i--) if (dep[fa[i][u]] > dep[x]) u = fa[i][u]; printf("%lld\n", SegmentTree::nodes[1] - SegmentTree::query(lft[u], rig[u], 1, n, 1)); } else printf("%lld\n", SegmentTree::query(lft[x], rig[x], 1, n, 1)); } } return 0; }
// QOJ2030.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

// Upper bound on the number of vertices (3e5 plus slack).
const int MAX_N = 3e5 + 200;

// n: vertex count; q: query count.
int n, q;
// seq: initial vertex weights; [lft[u], rig[u]]: DFS-order interval of u's subtree (tree rooted at 1);
// anti: inverse of lft (DFS position -> vertex).
int seq[MAX_N], lft[MAX_N], rig[MAX_N], anti[MAX_N];
// root: the current (virtual) root; fa[i][u]: 2^i-th ancestor of u in the tree rooted at 1.
int root, fa[20][MAX_N];
// head: first edge index of each adjacency list; current: next free slot in edges[].
int head[MAX_N], current;
// dep: depth (vertex 1 has depth 1, the sentinel 0 keeps depth 0); ptot: DFS timestamp counter.
int dep[MAX_N], ptot;

// One directed half of an undirected tree edge in a linked adjacency list:
// `to` is the target vertex, `nxt` the index of the next edge leaving the same source (-1 ends the list).
struct edge
{
    int to;
    int nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

namespace SegmentTree
{

ll nodes[MAX_N << 2], tag[MAX_N << 2];
#define lson (p << 1)
#define rson ((p << 1) | 1)
#define mid ((l + r) >> 1)

void build(int l, int r, int p)
{
    if (l == r)
        return (void)(nodes[p] = seq[anti[l]]);
    build(l, mid, lson), build(mid + 1, r, rson);
    nodes[p] = nodes[lson] + nodes[rson];
}

void pushdown(int p, int l, int r)
{
    if (tag[p] != 0)
    {
        tag[lson] += tag[p], tag[rson] += tag[p];
        nodes[lson] += 1LL * tag[p] * (mid - l + 1), nodes[rson] += 1LL * tag[p] * (r - mid);
        tag[p] = 0;
    }
}

void update(int ql, int qr, int l, int r, int p, ll val)
{
    if (ql <= l && r <= qr)
    {
        nodes[p] += 1LL * (r - l + 1) * val, tag[p] += val;
        return;
    }
    pushdown(p, l, r);
    if (ql <= mid)
        update(ql, qr, l, mid, lson, val);
    if (mid < qr)
        update(ql, qr, mid + 1, r, rson, val);
    nodes[p] = nodes[lson] + nodes[rson];
}

ll query(int ql, int qr, int l, int r, int p)
{
    if (ql <= l && r <= qr)
        return nodes[p];
    pushdown(p, l, r);
    ll ret = 0;
    if (ql <= mid)
        ret += query(ql, qr, l, mid, lson);
    if (mid < qr)
        ret += query(ql, qr, mid + 1, r, rson);
    return ret;
}

#undef mid
#undef rson
#undef lson

} // namespace SegmentTree

// Root the tree at u (with parent fat) and record, in DFS preorder, for every vertex v reached:
// depth dep[v], parent fa[0][v], entry time lft[v] (with anti[lft[v]] = v) and the last entry
// time inside its subtree rig[v], so subtree(v) occupies DFS positions [lft[v], rig[v]].
// Uses an explicit stack instead of recursion so a path-shaped tree with ~MAX_N vertices cannot
// overflow the call stack; vertices are visited in exactly the same order as the recursive DFS.
void dfs_init(int u, int fat)
{
    // Each frame: (vertex, index of the next adjacency edge still to scan; -1 when exhausted).
    std::vector<std::pair<int, int>> frames;
    auto enter = [&](int v, int parent) {
        dep[v] = dep[parent] + 1, fa[0][v] = parent, lft[v] = ++ptot, anti[ptot] = v;
        frames.emplace_back(v, head[v]);
    };
    enter(u, fat);
    while (!frames.empty())
    {
        const int v = frames.back().first;
        const int e = frames.back().second;
        if (e == -1)
        {
            // Every child is finished: close v's DFS interval.
            rig[v] = ptot;
            frames.pop_back();
            continue;
        }
        // Advance this frame's cursor before entering a child (emplace_back may reallocate).
        frames.back().second = edges[e].nxt;
        if (edges[e].to != fa[0][v])
            enter(edges[e].to, v);
    }
}

// Lowest common ancestor of x and y in the tree rooted at vertex 1, via binary lifting.
int getLCA(int x, int y)
{
    // Make x the deeper vertex, then lift it to y's depth.
    if (dep[x] < dep[y])
        swap(x, y);
    for (int k = 19; k >= 0; --k)
    {
        const int up = fa[k][x];
        if (dep[up] >= dep[y])
            x = up;
    }
    if (x == y)
        return x;
    // Lift both while their ancestors differ; they end as the two children of the LCA.
    for (int k = 19; k >= 0; --k)
    {
        if (fa[k][x] == fa[k][y])
            continue;
        x = fa[k][x];
        y = fa[k][y];
    }
    return fa[0][x];
}

// True when v lies in the subtree of anc in the tree rooted at vertex 1
// (DFS intervals are either nested or disjoint, so checking lft[v] suffices).
static bool inSubtree(int anc, int v)
{
    return lft[anc] <= lft[v] && lft[v] <= rig[anc];
}

// The child of anc (tree rooted at 1) whose subtree contains v.
// Requires v to be a proper descendant of anc.
static int childToward(int anc, int v)
{
    for (int i = 19; i >= 0; i--)
        if (dep[fa[i][v]] > dep[anc])
            v = fa[i][v];
    return v;
}

// Reads the tree and processes the queries:
//   1 x      -> make x the current root;
//   2 x y z  -> add z to every vertex in the subtree (w.r.t. the current root) of LCA(x, y);
//   3 x      -> print the weight sum of x's subtree w.r.t. the current root.
// Everything is built once with vertex 1 as the root; re-rooting is handled by case analysis.
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++)
        scanf("%d", &seq[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    dfs_init(1, 0), root = 1;
    SegmentTree::build(1, n, 1);
    for (int i = 1; i < 20; i++)
        for (int j = 1; j <= n; j++)
            fa[i][j] = fa[i - 1][fa[i - 1][j]];
    while (q--)
    {
        int opt, x, y, z;
        scanf("%d%d", &opt, &x);
        if (opt == 1)
            root = x;
        else if (opt == 2)
        {
            scanf("%d%d", &y, &z);
            // Under the current root, LCA(x, y) is the deepest of these three 1-rooted LCAs.
            int lca = getLCA(x, y), lca1 = getLCA(root, x), lca2 = getLCA(root, y);
            int dlca = (dep[lca1] > dep[lca2] ? lca1 : lca2);
            if (dep[lca] > dep[dlca])
                dlca = lca;
            if (dlca == root)
                SegmentTree::update(1, n, 1, n, 1, z);
            else if (inSubtree(dlca, root))
            {
                // Whole tree except the branch of dlca that leads towards root.
                int u = childToward(dlca, root);
                SegmentTree::update(1, n, 1, n, 1, z);
                SegmentTree::update(lft[u], rig[u], 1, n, 1, -z);
            }
            else
                SegmentTree::update(lft[dlca], rig[dlca], 1, n, 1, z);
        }
        else if (opt == 3)
        {
            if (x == root)
                printf("%lld\n", SegmentTree::nodes[1]);
            else if (inSubtree(x, root))
            {
                // Whole tree except the branch of x that leads towards root.
                int u = childToward(x, root);
                printf("%lld\n", SegmentTree::nodes[1] - SegmentTree::query(lft[u], rig[u], 1, n, 1));
            }
            else
                printf("%lld\n", SegmentTree::query(lft[x], rig[x], 1, n, 1));
        }
    }
    return 0;
}

B – Function

我们回顾一下式子:
\[ f(x, y) = \begin{cases} A_y & x = 1 \\ f(x-1, y) + A_y & y = 1 \text{ and } x \neq 1 \\ \min \{ f(x-1, y-1), f(x-1, y) \} + A_y & \text{otherwise} \end{cases} \]
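我们可以把这个式子理解为网格图上的最短路:从 \((x, y)\) 出发,每一步走到 \((x-1, y)\) 或 \((x-1, y-1)\)(\(y = 1\) 时只能走前者),直到走到第 \(1\) 行为止,经过格子 \((x', y')\) 的代价为 \(A_{y'}\)。可以猜到一个结论:最优路径一定是先连续地向左上方斜走,停在某个权值相对较小的列上,再一直向上直走。那么,假设停止时 \(y = i\)(需满足 \(x - y + i \ge 1\)),答案就是: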
\[ \text{设 } p[n] = \sum_{i = 1}^n A_i,\quad ans = A_i(x - y + i) + p[y] - p[i] \]
我们整理一下:
\[A_i(x - y) + A_i \cdot i + p[y] - p[i]\]
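我们可以把 \((x - y)\) 看作自变量,则每个 \(i\) 对应一条斜率为 \(A_i\)、截距为 \(A_i \cdot i - p[i]\) 的直线,答案就是这些直线在该点处的最小值再加上 \(p[y]\)。我们把询问按 \(y\) 离线排序,依次加入 \(i \le y\) 的直线;由于只需要最小值,只需保留这些直线的下凸包(下包络)。用单调栈维护下凸包,查询时在栈上二分出定义域包含 \(x - y\) 的那条直线并求值即可。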

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// QOJ2031.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Upper bound on n and q.
const int MAX_N = 5e5 + 200;

int n, q;
// stk[1..tail]: indices of the lines on the lower hull, slopes strictly increasing towards the top.
int stk[MAX_N << 1], tail;
// pos[n - k]: abscissa where stk[k] and stk[k - 1] intersect (filled for k >= 2).
double pos[MAX_N << 1];
// ai: the array A; prefix: its prefix sums p; constSuf[i] = A_i * i - p[i] (intercept of line i);
// ans: answers indexed by query id.
ll ai[MAX_N], prefix[MAX_N], constSuf[MAX_N], ans[MAX_N];
// Offline query: the pair (x, y) plus its input position `id` so answers can be printed in order.
// Ordered by y so that lines 1..y can be added incrementally.
struct query
{
    int x;
    int y;
    int id;
    bool operator<(const query &rhs) const
    {
        return rhs.y > y;
    }
} queries[MAX_N];
// Abscissa t where line x (slope ai[x], intercept constSuf[x]) meets line y:
// ai[x] * t + constSuf[x] == ai[y] * t + constSuf[y].
double calc(int x, int y)
{
    const double num = double(constSuf[y] - constSuf[x]);
    const double den = double(ai[x] - ai[y]);
    return num / den;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%lld", &ai[i]), prefix[i] = prefix[i - 1] + ai[i];
constSuf[i] = ai[i] * i - prefix[i];
}
scanf("%d", &q);
for (int i = 1; i <= q; i++)
scanf("%d%d", &queries[i].x, &queries[i].y), queries[i].id = i;
sort(queries + 1, queries + 1 + q);
int curt = 0;
for (int i = 1; i <= q; i++)
{
int delta = queries[i].x - queries[i].y;
while (curt + 1 <= queries[i].y)
{
curt++;
while (tail > 0 && ai[stk[tail]] >= ai[curt])
tail--;
while (tail > 1 && calc(curt, stk[tail]) > calc(stk[tail], stk[tail - 1]))
tail--;
stk[++tail] = curt;
if (tail > 1)
pos[n - tail] = calc(curt, stk[tail - 1]);
}
int idx = n - (lower_bound(pos + n - tail, pos + n - 1, delta) - pos);
ans[queries[i].id] = delta * ai[stk[idx]] + constSuf[stk[idx]] + prefix[curt];
}
for (int i = 1; i <= q; i++)
printf("%lld\n", ans[i]);
return 0;
}
// QOJ2031.cpp
#include <algorithm>
#include <cstdio>
#define ll long long

using namespace std;

const int MAX_N = 5e5 + 200;

// stk[1..tail]: lower-hull line indices (slopes strictly increasing upwards);
// pos[n - k]: intersection abscissa of stk[k] and stk[k - 1].
int n, q, stk[MAX_N << 1], tail;
double pos[MAX_N << 1];
// constSuf[i] = A_i * i - p[i] is the intercept of line i.
ll ai[MAX_N], prefix[MAX_N], constSuf[MAX_N], ans[MAX_N];

// Offline query ordered by y; id is its input position.
struct query
{
    int x, y, id;
    bool operator<(const query &rhs) const { return y < rhs.y; }
} queries[MAX_N];

// Abscissa where lines x and y intersect.
double calc(int x, int y) { return double(constSuf[y] - constSuf[x]) / double(ai[x] - ai[y]); }

// Answers min over i <= y of A_i * (x - y) + constSuf[i] + p[y] for each query, offline by y.
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%lld", &ai[i]);
        prefix[i] = prefix[i - 1] + ai[i];
        constSuf[i] = ai[i] * i - prefix[i];
    }
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &queries[i].x, &queries[i].y), queries[i].id = i;
    sort(queries + 1, queries + 1 + q);
    int curt = 0;
    for (int i = 1; i <= q; i++)
    {
        int delta = queries[i].x - queries[i].y;
        while (curt + 1 <= queries[i].y)
        {
            curt++;
            while (tail > 0 && ai[stk[tail]] >= ai[curt])
                tail--;
            while (tail > 1 && calc(curt, stk[tail]) > calc(stk[tail], stk[tail - 1]))
                tail--;
            stk[++tail] = curt;
            if (tail > 1)
                pos[n - tail] = calc(curt, stk[tail - 1]);
        }
        // Binary search the breakpoints to find the hull line optimal at abscissa delta.
        int idx = n - (lower_bound(pos + n - tail, pos + n - 1, delta) - pos);
        ans[queries[i].id] = delta * ai[stk[idx]] + constSuf[stk[idx]] + prefix[curt];
    }
    for (int i = 1; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
// QOJ2031.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

// Upper bound on n and q.
const int MAX_N = 5e5 + 200;

int n, q;
// stk[1..tail]: indices of the lines on the lower hull, slopes strictly increasing towards the top.
int stk[MAX_N << 1], tail;
// pos[n - k]: abscissa where stk[k] and stk[k - 1] intersect (filled for k >= 2).
double pos[MAX_N << 1];
// ai: the array A; prefix: its prefix sums p; constSuf[i] = A_i * i - p[i] (intercept of line i);
// ans: answers indexed by query id.
ll ai[MAX_N], prefix[MAX_N], constSuf[MAX_N], ans[MAX_N];

// Offline query: the pair (x, y) plus its input position `id` so answers can be printed in order.
// Ordered by y so that lines 1..y can be added incrementally.
struct query
{
    int x;
    int y;
    int id;
    bool operator<(const query &rhs) const
    {
        return rhs.y > y;
    }
} queries[MAX_N];

// Abscissa t where line x (slope ai[x], intercept constSuf[x]) meets line y:
// ai[x] * t + constSuf[x] == ai[y] * t + constSuf[y].
double calc(int x, int y)
{
    const double num = double(constSuf[y] - constSuf[x]);
    const double den = double(ai[x] - ai[y]);
    return num / den;
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%lld", &ai[i]), prefix[i] = prefix[i - 1] + ai[i];
        constSuf[i] = ai[i] * i - prefix[i];
    }
    scanf("%d", &q);
    for (int i = 1; i <= q; i++)
        scanf("%d%d", &queries[i].x, &queries[i].y), queries[i].id = i;
    sort(queries + 1, queries + 1 + q);
    int curt = 0;
    for (int i = 1; i <= q; i++)
    {
        int delta = queries[i].x - queries[i].y;
        while (curt + 1 <= queries[i].y)
        {
            curt++;
            while (tail > 0 && ai[stk[tail]] >= ai[curt])
                tail--;
            while (tail > 1 && calc(curt, stk[tail]) > calc(stk[tail], stk[tail - 1]))
                tail--;
            stk[++tail] = curt;
            if (tail > 1)
                pos[n - tail] = calc(curt, stk[tail - 1]);
        }
        int idx = n - (lower_bound(pos + n - tail, pos + n - 1, delta) - pos);
        ans[queries[i].id] = delta * ai[stk[idx]] + constSuf[stk[idx]] + prefix[curt];
    }
    for (int i = 1; i <= q; i++)
        printf("%lld\n", ans[i]);
    return 0;
}

C – Or

这是一个很显然的 DP:设 \(dp[i][j]\) 为考虑前 \(i\) 个数、已经选出了 \(j\) 位 \(1\) 的合法方案数。正常的转移非常简单:
\[dp[i][j] = \sum_{c = 1}^j dp[i - 1][j - c] {j \choose c} 2^{j - c}\]
其中,\(c\) 为这次新增的 \(1\) 的个数。这样的转移是 \(\Theta(nk^2)\) 的。然而,我们发现这个式子的卷积形式非常明显:
\[ \begin{aligned}dp[i][j] &= \sum_{c = 1}^j \frac{ dp[i - 1][j - c] 2^{j - c} } { (j - c)! } \cdot \frac{j!}{c!} \\ \frac{dp[i][j]}{j!} &= \sum_{c = 1}^j \frac{dp[i - 1][j - c] 2^{j - c}}{(j - c)!} \cdot \frac{1}{c!} \end{aligned}\]
稍微设置一下:

\[\begin{aligned} A_i(x) &= \sum_{j = 0}^\infty \frac{dp[i][j]}{j!} x^j = \sum_{j = 0}^\infty x^j \sum_{c = 1}^j \frac{dp[i - 1][j - c]2^{j - c}}{(j - c)!} \cdot \frac{1}{c!} \\ F_i(x) &= \sum_{c = 0}^\infty \frac{dp[i - 1][c]2^c}{c!} x^c, \quad G_i(x) = \sum_{c = 1}^\infty \frac{1}{c!} x^c = e^x - 1 \\ A_i(x) &= F_i(x)G_i(x) \end{aligned}\]
直接调用 \(n\) 次 NTT 即可算完,但这样太慢了,我们需要利用更多已有的信息。我们注意到:
\[\begin{aligned} A_{i-1}(2x) &= \sum_{c = 0}^\infty \frac{dp[i - 1][c]}{c!} (2x)^c = \sum_{c = 0}^\infty \frac{dp[i - 1][c]2^c}{c!} x^c \\ &= F_i(x) \end{aligned}\]
Splendid!调整一下:
\[A_i(x) = F_i(x)G_i(x) = A_{i-1}(2x)G_i(x)\]
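把递推展开可得 \(A_n(x) = \prod_{t=0}^{n-1} G(2^t x)\)(其中 \(G(x) = e^x - 1\)),于是有 \(A_{a+b}(x) = A_a(2^b x)\,A_b(x)\)。像快速幂一样递归计算:\(n\) 为奇数时 \(A_n(x) = A_{n-1}(x)\,G(2^{n-1}x)\),\(n\) 为偶数时 \(A_n(x) = A_{n/2}(x)\,A_{n/2}(2^{n/2}x)\)。每层只需一次长度为 \(O(k)\) 的 NTT,共 \(O(\log n)\) 层,复杂度为 \(\Theta(k \log k \log n)\)。最后答案为 \(\sum_{j=n}^{k} {k \choose j} dp[n][j]\)。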

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// QOJ2032.cpp
#include <bits/stdc++.h>
using namespace std;
// Array bound, the NTT-friendly prime modulus, and its primitive root.
const int MAX_N = 1e5 + 200;
const int mod = 998244353;
const int G = 3;

// n: sequence length; K: number of bits; poly_bit / poly_siz: NTT length exponent and length (2^poly_bit).
int n, K, poly_bit, poly_siz;
// rev: bit-reversal permutation; fac / fac_inv: factorials and their inverses; pow2: powers of two.
int rev[MAX_N], fac[MAX_N], fac_inv[MAX_N], pow2[MAX_N];
// f: the polynomial produced by solve(); g: scratch operand for polyMultiply.
int f[MAX_N], g[MAX_N];
// bas^tim modulo mod by binary exponentiation (tim is expected to be non-negative).
int quick_pow(int bas, int tim)
{
    int result = 1;
    for (; tim != 0; tim >>= 1)
    {
        if (tim & 1)
            result = 1LL * result * bas % mod;
        bas = 1LL * bas * bas % mod;
    }
    return result;
}
// Modular inverse of the primitive root G (Fermat: G^(mod-2)); used by the inverse NTT.
const int Gi{quick_pow(G, mod - 2)};
// Fill rev[] with the bit-reversal permutation of 0 .. poly_siz - 1 (poly_siz = 2^poly_bit).
void ntt_initialize()
{
    for (int i = 0; i < poly_siz; ++i)
    {
        const int low = i & 1;
        rev[i] = (rev[i >> 1] >> 1) | (low << (poly_bit - 1));
    }
}
// In-place iterative radix-2 NTT of arr[0 .. poly_siz) modulo mod, using the rev[] permutation.
// dft == 1: forward transform (root G); dft == -1: inverse transform (root Gi, then scaled by 1 / poly_siz).
void ntt(int *arr, int dft)
{
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    const int base = (dft == 1 ? G : Gi);
    for (int half = 1; half < poly_siz; half <<= 1)
    {
        const int len = half << 1;
        // len-th root of unity (or its inverse) for this butterfly level.
        const int wn = quick_pow(base, (mod - 1) / len);
        for (int start = 0; start < poly_siz; start += len)
        {
            int w = 1;
            for (int k = start; k < start + half; k++)
            {
                const int lo = arr[k];
                const int hi = 1LL * arr[k + half] * w % mod;
                arr[k + half] = (0LL + lo + mod - hi) % mod;
                arr[k] = (1LL * lo + hi) % mod;
                w = 1LL * w * wn % mod;
            }
        }
    }
    if (dft != -1)
        return;
    const int inv = quick_pow(poly_siz, mod - 2);
    for (int i = 0; i < poly_siz; i++)
        arr[i] = 1LL * arr[i] * inv % mod;
}
void polyMultiply(int *A, int *B)
{
ntt(A, 1), ntt(B, 1);
for (int i = 0; i < poly_siz; i++)
A[i] = 1LL * A[i] * B[i] % mod;
ntt(A, -1);
for (int i = 0; i < poly_siz; i++)
B[i] = 0;
for (int i = K + 1; i < poly_siz; i++)
A[i] = 0;
}
// Scale coefficient i of A (0 <= i <= K) by coeff^i, i.e. A(x) -> A(coeff * x) truncated to degree K.
void multiply(int *A, int coeff)
{
    int power = 1;
    for (int i = 0; i <= K; i++)
    {
        A[i] = 1LL * A[i] * power % mod;
        power = 1LL * power * coeff % mod;
    }
}
// Leave in f the EGF F_idx(x) = sum_j dp[idx][j] x^j / j!, truncated to degree K, using
// F_{a+b}(x) = F_a(2^b x) * F_b(x) with F_1(x) = e^x - 1 (binary-exponentiation style recursion).
void solve(int idx)
{
    if (idx == 0)
    {
        f[0] = 1;
        return;
    }
    if (idx % 2 == 0)
    {
        // F_{2h}(x) = F_h(x) * F_h(2^h x).
        const int h = idx >> 1;
        solve(h);
        for (int i = 0; i <= K; i++)
            g[i] = f[i];
        multiply(g, quick_pow(2, h));
    }
    else
    {
        // F_idx(x) = F_{idx-1}(x) * (e^{2^{idx-1} x} - 1).
        solve(idx - 1);
        g[0] = 0;
        for (int i = 1; i <= K; i++)
            g[i] = fac_inv[i];
        multiply(g, quick_pow(2, idx - 1));
    }
    polyMultiply(f, g);
}
// Reads n and K, builds f = EGF of dp[n][*] via solve(n), and prints
// sum_{i=n}^{K} C(K, i) * dp[n][i] modulo mod.
int main()
{
    scanf("%d%d", &n, &K);
    // Every number must contribute at least one new bit (c >= 1 in the recurrence), so n > K
    // admits no sequence. Returning here also keeps all array work bounded by K.
    // The unused pow2 table is no longer filled up to max(n, K), which overran pow2[MAX_N] for large n.
    if (n > K)
    {
        puts("0");
        return 0;
    }
    fac[0] = 1;
    for (int i = 1; i <= K; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[K] = quick_pow(fac[K], mod - 2);
    for (int i = K - 1; i >= 0; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    // NTT length: smallest power of two exceeding 2K, so degree-K products never wrap around.
    while ((1 << poly_bit) <= (K << 1))
        poly_bit++;
    poly_siz = (1 << poly_bit), ntt_initialize();
    solve(n);
    // f[i] = dp[n][i] / i!, so f[i] * K! / (K - i)! = C(K, i) * dp[n][i].
    int ans = 0;
    for (int i = n; i <= K; i++)
        ans = (1LL * ans + 1LL * f[i] * fac[K] % mod * fac_inv[K - i] % mod) % mod;
    printf("%d\n", ans);
    return 0;
}
// QOJ2032.cpp
#include <algorithm>
#include <cstdio>

using namespace std;

// Array bound, NTT prime modulus and its primitive root.
const int MAX_N = 1e5 + 200, mod = 998244353, G = 3;

int n, K, poly_bit, poly_siz, rev[MAX_N], fac[MAX_N], fac_inv[MAX_N];
// f: result polynomial of solve(); g: scratch operand.
int f[MAX_N], g[MAX_N];

// bas^tim modulo mod.
int quick_pow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

const int Gi = quick_pow(G, mod - 2);

// Bit-reversal permutation for length poly_siz = 2^poly_bit.
void ntt_initialize()
{
    for (int i = 0; i < poly_siz; i++)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (poly_bit - 1)));
}

// In-place NTT; dft == 1 forward, dft == -1 inverse (scaled by 1 / poly_siz).
void ntt(int *arr, int dft)
{
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    for (int step = 1; step < poly_siz; step <<= 1)
    {
        int omega_n = quick_pow((dft == 1 ? G : Gi), (mod - 1) / (step << 1));
        for (int j = 0; j < poly_siz; j += (step << 1))
        {
            int omega_nk = 1;
            for (int k = j; k < j + step; k++, omega_nk = 1LL * omega_nk * omega_n % mod)
            {
                int t = 1LL * arr[k + step] * omega_nk % mod;
                arr[k + step] = (0LL + arr[k] + mod - t) % mod;
                arr[k] = (1LL * arr[k] + t) % mod;
            }
        }
    }
    if (dft == -1)
    {
        int inv = quick_pow(poly_siz, mod - 2);
        for (int i = 0; i < poly_siz; i++)
            arr[i] = 1LL * arr[i] * inv % mod;
    }
}

// A <- A * B truncated to degree K; B is cleared afterwards.
void polyMultiply(int *A, int *B)
{
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < poly_siz; i++)
        A[i] = 1LL * A[i] * B[i] % mod;
    ntt(A, -1);
    fill(B, B + poly_siz, 0);
    fill(A + K + 1, A + poly_siz, 0);
}

// A(x) -> A(coeff * x), truncated to degree K.
void multiply(int *A, int coeff)
{
    for (int i = 0, acc = 1; i <= K; i++, acc = 1LL * acc * coeff % mod)
        A[i] = 1LL * A[i] * acc % mod;
}

// f <- F_idx(x) = sum_j dp[idx][j] x^j / j!, via F_{a+b}(x) = F_a(2^b x) F_b(x), F_1 = e^x - 1.
void solve(int idx)
{
    if (idx == 0)
        return (void)(f[0] = 1);
    if (idx & 1)
    {
        solve(idx - 1), g[0] = 0;
        for (int i = 1; i <= K; i++)
            g[i] = fac_inv[i];
        multiply(g, quick_pow(2, idx - 1));
        polyMultiply(f, g);
    }
    else
    {
        solve(idx >> 1);
        for (int i = 0; i <= K; i++)
            g[i] = f[i];
        multiply(g, quick_pow(2, idx >> 1));
        polyMultiply(f, g);
    }
}

// Prints sum_{i=n}^{K} C(K, i) dp[n][i] mod mod; 0 when n > K (each number needs a new bit).
int main()
{
    scanf("%d%d", &n, &K);
    if (n > K)
    {
        puts("0");
        return 0;
    }
    fac[0] = 1;
    for (int i = 1; i <= K; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[K] = quick_pow(fac[K], mod - 2);
    for (int i = K - 1; i >= 0; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    while ((1 << poly_bit) <= (K << 1))
        poly_bit++;
    poly_siz = (1 << poly_bit), ntt_initialize();
    solve(n);
    int ans = 0;
    for (int i = n; i <= K; i++)
        ans = (1LL * ans + 1LL * f[i] * fac[K] % mod * fac_inv[K - i] % mod) % mod;
    printf("%d\n", ans);
    return 0;
}
// QOJ2032.cpp
#include <bits/stdc++.h>

using namespace std;

// Array bound, the NTT-friendly prime modulus, and its primitive root.
const int MAX_N = 1e5 + 200;
const int mod = 998244353;
const int G = 3;

// n: sequence length; K: number of bits; poly_bit / poly_siz: NTT length exponent and length (2^poly_bit).
int n, K, poly_bit, poly_siz;
// rev: bit-reversal permutation; fac / fac_inv: factorials and their inverses; pow2: powers of two.
int rev[MAX_N], fac[MAX_N], fac_inv[MAX_N], pow2[MAX_N];
// f: the polynomial produced by solve(); g: scratch operand for polyMultiply.
int f[MAX_N], g[MAX_N];

// bas^tim modulo mod by binary exponentiation (tim is expected to be non-negative).
int quick_pow(int bas, int tim)
{
    int result = 1;
    for (; tim != 0; tim >>= 1)
    {
        if (tim & 1)
            result = 1LL * result * bas % mod;
        bas = 1LL * bas * bas % mod;
    }
    return result;
}

// Modular inverse of the primitive root G (Fermat: G^(mod-2)); used by the inverse NTT.
const int Gi{quick_pow(G, mod - 2)};

// Fill rev[] with the bit-reversal permutation of 0 .. poly_siz - 1 (poly_siz = 2^poly_bit).
void ntt_initialize()
{
    for (int i = 0; i < poly_siz; ++i)
    {
        const int low = i & 1;
        rev[i] = (rev[i >> 1] >> 1) | (low << (poly_bit - 1));
    }
}

// In-place iterative radix-2 NTT of arr[0 .. poly_siz) modulo mod, using the rev[] permutation.
// dft == 1: forward transform (root G); dft == -1: inverse transform (root Gi, then scaled by 1 / poly_siz).
void ntt(int *arr, int dft)
{
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    const int base = (dft == 1 ? G : Gi);
    for (int half = 1; half < poly_siz; half <<= 1)
    {
        const int len = half << 1;
        // len-th root of unity (or its inverse) for this butterfly level.
        const int wn = quick_pow(base, (mod - 1) / len);
        for (int start = 0; start < poly_siz; start += len)
        {
            int w = 1;
            for (int k = start; k < start + half; k++)
            {
                const int lo = arr[k];
                const int hi = 1LL * arr[k + half] * w % mod;
                arr[k + half] = (0LL + lo + mod - hi) % mod;
                arr[k] = (1LL * lo + hi) % mod;
                w = 1LL * w * wn % mod;
            }
        }
    }
    if (dft != -1)
        return;
    const int inv = quick_pow(poly_siz, mod - 2);
    for (int i = 0; i < poly_siz; i++)
        arr[i] = 1LL * arr[i] * inv % mod;
}

void polyMultiply(int *A, int *B)
{
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < poly_siz; i++)
        A[i] = 1LL * A[i] * B[i] % mod;
    ntt(A, -1);
    for (int i = 0; i < poly_siz; i++)
        B[i] = 0;
    for (int i = K + 1; i < poly_siz; i++)
        A[i] = 0;
}

// Scale coefficient i of A (0 <= i <= K) by coeff^i, i.e. A(x) -> A(coeff * x) truncated to degree K.
void multiply(int *A, int coeff)
{
    int power = 1;
    for (int i = 0; i <= K; i++)
    {
        A[i] = 1LL * A[i] * power % mod;
        power = 1LL * power * coeff % mod;
    }
}

// Leave in f the EGF F_idx(x) = sum_j dp[idx][j] x^j / j!, truncated to degree K, using
// F_{a+b}(x) = F_a(2^b x) * F_b(x) with F_1(x) = e^x - 1 (binary-exponentiation style recursion).
void solve(int idx)
{
    if (idx == 0)
    {
        f[0] = 1;
        return;
    }
    if (idx % 2 == 0)
    {
        // F_{2h}(x) = F_h(x) * F_h(2^h x).
        const int h = idx >> 1;
        solve(h);
        for (int i = 0; i <= K; i++)
            g[i] = f[i];
        multiply(g, quick_pow(2, h));
    }
    else
    {
        // F_idx(x) = F_{idx-1}(x) * (e^{2^{idx-1} x} - 1).
        solve(idx - 1);
        g[0] = 0;
        for (int i = 1; i <= K; i++)
            g[i] = fac_inv[i];
        multiply(g, quick_pow(2, idx - 1));
    }
    polyMultiply(f, g);
}

// Reads n and K, builds f = EGF of dp[n][*] via solve(n), and prints
// sum_{i=n}^{K} C(K, i) * dp[n][i] modulo mod.
int main()
{
    scanf("%d%d", &n, &K);
    // Every number must contribute at least one new bit (c >= 1 in the recurrence), so n > K
    // admits no sequence. Returning here also keeps all array work bounded by K.
    // The unused pow2 table is no longer filled up to max(n, K), which overran pow2[MAX_N] for large n.
    if (n > K)
    {
        puts("0");
        return 0;
    }
    fac[0] = 1;
    for (int i = 1; i <= K; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[K] = quick_pow(fac[K], mod - 2);
    for (int i = K - 1; i >= 0; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    // NTT length: smallest power of two exceeding 2K, so degree-K products never wrap around.
    while ((1 << poly_bit) <= (K << 1))
        poly_bit++;
    poly_siz = (1 << poly_bit), ntt_initialize();
    solve(n);
    // f[i] = dp[n][i] / i!, so f[i] * K! / (K - i)! = C(K, i) * dp[n][i].
    int ans = 0;
    for (int i = n; i <= K; i++)
        ans = (1LL * ans + 1LL * f[i] * fac[K] % mod * fac_inv[K - i] % mod) % mod;
    printf("%d\n", ans);
    return 0;
}

总结

我认为这是一套比较好的入门省选训练题,不会过于毒瘤也不会过于偏向联赛,科技题也不会过于裸露,而签到题也很友好。

Leave a Reply

Your email address will not be published. Required fields are marked *