Loading [MathJax]/extensions/tex2jax.js

「Fortuna OJ」Jul 8th – Group B 解题报告

今天这几题咋就那么毒瘤呢?

A – String

其实这道题考场上应该能做出来的,是一道很简单的计数问题。在考场上一看到字符串就懵了,不知道为什么我对字符串的字典序有阴影,现在看来就是一道傻逼题。

考虑这样计数:

  1. 枚举一个\(i\),考虑前\(i-1\)个字符与\(T\)串相同,然后第\(i\)个字符小于\(T[i]\),这样可以保证后面怎么放置字母都能满足要求。
  2. 在枚举了\(i\)的情况下,枚举字符\(ch\),范围在\([a, T[i])\),考虑以下两种情况对答案的贡献:
    1. 如果\(ch = S[i]\),那么第\(i\)位没有产生变动,后面\(n - i\)个位置中需要恰好变动\(k\)个字符(每个变动的位置有\(25\)种选择),对答案的贡献:\[ \binom{n - i}{k} 25^{k} \]
    2. 如果不等于,那么第\(i\)位本身已经占用了一次变动,后面只需要变动\(k - 1\)个字符;对答案的贡献(当\(k - 1 < 0\)时为\(0\)):\[ \binom{n - i}{k - 1} 25^{k - 1} \]
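  3. 贡献之后,如果本位\(T[i] \neq S[i]\),那么令\(k \gets k - 1\)(即 `k--`),表示前缀中已经固定了一个与\(S\)不同的字符。枚举结束后统计到的是字典序严格小于\(T\)的串,最后输出时再加上\(1\),对应\(T\)本身(前提是\(T\)与\(S\)恰好有\(k\)位不同)。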

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// string.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity of every per-position array (strings are stored 1-indexed).
const int MAX_N = 100000 + 2000;
// Prime modulus used for all counting.
const int mod = 1000000007;

// n: string length, k: number of positions that still have to differ from S,
// level[i] = i! mod `mod`, level_inv[i] = (i!)^-1 mod `mod`, ans: running count.
ll n, k, level[MAX_N], level_inv[MAX_N], ans;
// Input strings, read into S + 1 and T + 1.
char S[MAX_N], T[MAX_N];
// Binary exponentiation: returns bas^tim modulo `mod`.
// A negative exponent yields 0, so terms that would need a negative count vanish.
ll quick_pow(ll bas, ll tim)
{
    if (tim < 0)
        return 0;
    ll result = 1;
    for (; tim > 0; tim >>= 1, bas = bas * bas % mod)
        if (tim & 1)
            result = result * bas % mod;
    return result;
}
// Binomial coefficient C(n_, k_) modulo `mod`, read from the factorial tables
// level[] / level_inv[] that main() fills.
// Returns 0 when k_ < 0 or k_ > n_. Both cases arise from main()'s arguments
// (k - 1 == -1, or more required changes than remaining positions), and indexing
// the tables with them would read out of bounds.
ll comb(ll n_, ll k_)
{
    if (k_ < 0 || k_ > n_)
        return 0;
    return level[n_] * level_inv[n_ - k_] % mod * level_inv[k_] % mod;
}
// Reads n, k, S and T, then counts the strings that differ from S in exactly k
// positions and are lexicographically smaller than T. It prints that count + 1.
int main()
{
    scanf("%lld%lld", &n, &k);
    // Factorials and inverse factorials used by comb().
    level[0] = 1;
    for (int i = 1; i <= n; i++)
        level[i] = level[i - 1] * i % mod;
    level_inv[n] = quick_pow(level[n], mod - 2);
    for (int i = n - 1; i >= 1; i--)
        level_inv[i] = level_inv[i + 1] * (i + 1) % mod;
    level_inv[0] = 1;
    scanf("%s", S + 1), scanf("%s", T + 1);

    for (int i = 1; i <= n; i++)
    {
        // Keep positions 1..i-1 equal to T and put a letter below T[i] at position i.
        // Positions i+1..n are then unconstrained lexicographically.
        int smaller = T[i] - 'a';   // candidate letters 'a' .. T[i]-1
        int s_rank = S[i] - 'a';
        // Is S[i] itself one of the candidates? Choosing it costs no change here.
        int same = (s_rank >= 0 && s_rank < smaller) ? 1 : 0;
        // Every other candidate spends one of the k changes at position i.
        int other = smaller - same;
        // Both terms are independent of the chosen letter, so compute each once.
        if (same)
            (ans += comb(n - i, k) * quick_pow(25, k) % mod) %= mod;
        if (other > 0)
            (ans += other * (comb(n - i, k - 1) * quick_pow(25, k - 1) % mod) % mod) %= mod;
        // Following T at position i uses up a change whenever T differs from S there.
        if (S[i] != T[i])
            k--;
    }
    // ans counts strings strictly below T. The +1 accounts for T itself.
    // NOTE(review): this assumes T differs from S in exactly k positions; confirm with the statement.
    printf("%lld", (ans + 1) % mod);
    return 0;
}
// string.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int MAX_N = 1e5 + 2000, mod = 1e9 + 7; ll n, k, level[MAX_N], level_inv[MAX_N], ans; char S[MAX_N], T[MAX_N]; ll quick_pow(ll bas, ll tim) { if (tim < 0) return 0; ll ans = 1; while (tim) { if (tim & 1) ans = ans * bas % mod; bas = bas * bas % mod; tim >>= 1; } return ans; } ll comb(ll n_, ll k_) { return level[n_] * level_inv[n_ - k_] % mod * level_inv[k_] % mod; } int main() { scanf("%lld%lld", &n, &k), level[0] = 1; for (int i = 1; i <= n; i++) level[i] = level[i - 1] * i % mod; level_inv[n] = quick_pow(level[n], mod - 2); for (int i = n - 1; i >= 1; i--) level_inv[i] = level_inv[i + 1] * (i + 1) % mod; level_inv[0] = 1; scanf("%s", S + 1), scanf("%s", T + 1); for (int i = 1; i <= n; i++) { for (int ch = 0; ch < T[i] - 'a'; ch++) { if (ch + 'a' == S[i]) (ans += comb(n - i, k) * quick_pow(25, k) % mod) %= mod; else (ans += comb(n - i, k - 1) * quick_pow(25, k - 1) % mod) %= mod; } if (S[i] != T[i]) k--; } printf("%lld", (ans + 1) % mod); return 0; }
// string.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

// Capacity of every per-position array (strings are stored 1-indexed).
const int MAX_N = 100000 + 2000;
// Prime modulus used for all counting.
const int mod = 1000000007;

// n: string length, k: number of positions that still have to differ from S,
// level[i] = i! mod `mod`, level_inv[i] = (i!)^-1 mod `mod`, ans: running count.
ll n, k, level[MAX_N], level_inv[MAX_N], ans;
// Input strings, read into S + 1 and T + 1.
char S[MAX_N], T[MAX_N];

// Binary exponentiation: returns bas^tim modulo `mod`.
// A negative exponent yields 0, so terms that would need a negative count vanish.
ll quick_pow(ll bas, ll tim)
{
    if (tim < 0)
        return 0;
    ll result = 1;
    for (; tim > 0; tim >>= 1, bas = bas * bas % mod)
        if (tim & 1)
            result = result * bas % mod;
    return result;
}

// Binomial coefficient C(n_, k_) modulo `mod`, read from the factorial tables
// level[] / level_inv[] that main() fills.
// Returns 0 when k_ < 0 or k_ > n_. Both cases arise from main()'s arguments
// (k - 1 == -1, or more required changes than remaining positions), and indexing
// the tables with them would read out of bounds.
ll comb(ll n_, ll k_)
{
    if (k_ < 0 || k_ > n_)
        return 0;
    return level[n_] * level_inv[n_ - k_] % mod * level_inv[k_] % mod;
}

// Reads n, k, S and T, then counts the strings that differ from S in exactly k
// positions and are lexicographically smaller than T. It prints that count + 1.
int main()
{
    scanf("%lld%lld", &n, &k);
    // Factorials and inverse factorials used by comb().
    level[0] = 1;
    for (int i = 1; i <= n; i++)
        level[i] = level[i - 1] * i % mod;
    level_inv[n] = quick_pow(level[n], mod - 2);
    for (int i = n - 1; i >= 1; i--)
        level_inv[i] = level_inv[i + 1] * (i + 1) % mod;
    level_inv[0] = 1;
    scanf("%s", S + 1), scanf("%s", T + 1);

    for (int i = 1; i <= n; i++)
    {
        // Keep positions 1..i-1 equal to T and put a letter below T[i] at position i.
        // Positions i+1..n are then unconstrained lexicographically.
        int smaller = T[i] - 'a';   // candidate letters 'a' .. T[i]-1
        int s_rank = S[i] - 'a';
        // Is S[i] itself one of the candidates? Choosing it costs no change here.
        int same = (s_rank >= 0 && s_rank < smaller) ? 1 : 0;
        // Every other candidate spends one of the k changes at position i.
        int other = smaller - same;
        // Both terms are independent of the chosen letter, so compute each once.
        if (same)
            (ans += comb(n - i, k) * quick_pow(25, k) % mod) %= mod;
        if (other > 0)
            (ans += other * (comb(n - i, k - 1) * quick_pow(25, k - 1) % mod) % mod) %= mod;
        // Following T at position i uses up a change whenever T differs from S there.
        if (S[i] != T[i])
            k--;
    }
    // ans counts strings strictly below T. The +1 accounts for T itself.
    // NOTE(review): this assumes T differs from S in exactly k positions; confirm with the statement.
    printf("%lld", (ans + 1) % mod);
    return 0;
}

B – Running

这道题是一道好题,我对数论中循环的东西目前了解得很少,现在正好来了解一下。

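手玩样例可以发现规律:对于一个步数为\(a_i\)的人,它能到达的格子编号恰好是\(\gcd(a_i, n)\)的所有倍数。利用这个性质,我们可以枚举\(n\)的因数\(d\),然后看是否存在某个\(i\)满足\(\gcd(a_i, n) \mid d\)。如果存在,那么\(d\)的倍数都会被标记;特别地,所有满足\(\gcd(x, n) = d\)的格子\(x\)都会被经过,这样的格子恰好有\(\varphi(\frac{n}{d})\)个,于是对答案加上\(\varphi(\frac{n}{d})\)。最后用\(n\)减去被经过的格子数,就得到没有被经过的格子数。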

但为什么\(\varphi\)可以保证计数不重复呢?这个问题一开始也困扰了我,但是之后突然顿悟:考虑一种重复的情况\(d, k_1 d, k_2 d, k_3 d \dots\),列式:

\[ \begin{gathered} \gcd(a_x, n) \mid d \\ \gcd(a_x, n) \mid k_1 d \\ \gcd(a_x, n) \mid k_2 d \\ \dots \end{gathered} \]

现在推出:

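\[ \begin{gathered} \sum_{i = 1}^{r} \varphi\left(\frac{n}{k_i d}\right), \quad k_i d \mid n \\ \sum_{k \mid \frac{n}{d}} \varphi\left(\frac{n}{k d}\right) = \sum_{y \mid \frac{n}{d}} \varphi(y) = \frac{n}{d} \end{gathered} \]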

其中\(d\)为最小的、合法的(即\(\exists i, \gcd(a_i, n) \mid d\))因数。这样一来,这一串因数最后对答案的总贡献就等于\(\frac{n}{d}\),也就是\([0, n)\)中\(d\)的倍数的个数,每个格子恰好被计一次。

证明得差不多了, 其他地方再去理解理解。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// running.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 60, MAX_DOM = 1e7 + 200;
// Capacity of prime[]. The linear sieve stores every prime below MAX_DOM, 1-indexed.
// pi(10^7) = 664579, and [10^7, MAX_DOM) adds at most 200 more, so 700000 suffices.
// Raise this if MAX_DOM grows. Sizing prime[] by MAX_DOM would cost ~80 MB.
const int MAX_PRIME = 700000;

// n, m and arr[1..m] are read by main(). After sieve(), phi[x] is the prefix sum
// phi(1) + ... + phi(x) for x < MAX_DOM. ump memoises larger prefix sums for varphi().
long long n, m, arr[MAX_N], phi[MAX_DOM], prime[MAX_PRIME], tot, ans;
std::unordered_map<long long, long long> ump;
// vis[x] is true once x has been marked composite by the sieve.
bool vis[MAX_DOM];
// Greatest common divisor by the iterative Euclidean algorithm; gcd(a, 0) == a.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Returns the prefix sum phi(1) + ... + phi(num) of Euler's totient (not phi(num) itself).
// Small arguments come from the sieved table, so sieve() must run first. Larger ones
// use the identity sum_{d=1}^{num} S(num / d) = num * (num + 1) / 2 over the distinct
// quotients num / d. Their results are memoised in ump.
ll varphi(ll num)
{
    if (num < MAX_DOM)
        return phi[num];
    auto cached = ump.find(num);
    if (cached != ump.end())
        return cached->second;
    // num * (num + 1) / 2 without overflowing the intermediate product. Halve whichever
    // factor is even. The value itself must still fit in a long long, which holds
    // for num up to about 4.29e9.
    ll ans = (num % 2 == 0) ? (num / 2) * (num + 1) : num * ((num + 1) / 2);
    // Every d in [d, gx] shares the quotient num / d, so subtract them as one group.
    for (ll d = 2, gx; d <= num; d = gx + 1)
    {
        gx = num / (num / d);
        ans -= (gx - d + 1) * varphi(num / d);
    }
    ump[num] = ans;
    return ans;
}
void sieve()
{
phi[1] = 1;
for (ll i = 2; i < MAX_DOM; i++)
{
if (!vis[i])
prime[++tot] = i, phi[i] = i - 1;
for (ll j = 1; j <= tot && i * prime[j] < MAX_DOM; j++)
{
vis[i * prime[j]] = true;
if (i % prime[j] == 0)
{
phi[i * prime[j]] = phi[i] * prime[j];
break;
}
phi[i * prime[j]] = phi[i] * phi[prime[j]];
}
}
for (int i = 2; i < MAX_DOM; i++)
phi[i] += phi[i - 1];
}
void solve(int factor)
{
for (int i = 1; i <= m; i++)
if (factor % gcd(arr[i], n) == 0)
{
ans += varphi(n / factor) - varphi(n / factor - 1);
break;
}
}
int main()
{
freopen("running.in", "r", stdin);
freopen("running.out", "w", stdout);
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= m; i++)
scanf("%d", &arr[i]);
sieve();
for (int i = 1; i * i <= n; i++)
if (n % i == 0)
{
solve(i);
if (n / i != i)
solve(n / i);
}
printf("%lld", n - ans);
return 0;
}
// running.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int MAX_N = 60, MAX_DOM = 1e7 + 200; ll n, m, arr[MAX_N], phi[MAX_DOM], prime[MAX_DOM], tot, ans; unordered_map<ll, ll> ump; bool vis[MAX_DOM]; ll gcd(ll a, ll b) { return b == 0 ? a : gcd(b, a % b); } ll varphi(ll num) { if (num < MAX_DOM) return phi[num]; if (ump.count(num)) return ump[num]; ll ans = (num * (num + 1)) >> 1; for (ll d = 2, gx; d <= num; d = gx + 1) { gx = num / (num / d); ans -= (gx - d + 1) * varphi(num / d); } return ump[num] = ans; } void sieve() { phi[1] = 1; for (ll i = 2; i < MAX_DOM; i++) { if (!vis[i]) prime[++tot] = i, phi[i] = i - 1; for (ll j = 1; j <= tot && i * prime[j] < MAX_DOM; j++) { vis[i * prime[j]] = true; if (i % prime[j] == 0) { phi[i * prime[j]] = phi[i] * prime[j]; break; } phi[i * prime[j]] = phi[i] * phi[prime[j]]; } } for (int i = 2; i < MAX_DOM; i++) phi[i] += phi[i - 1]; } void solve(int factor) { for (int i = 1; i <= m; i++) if (factor % gcd(arr[i], n) == 0) { ans += varphi(n / factor) - varphi(n / factor - 1); break; } } int main() { freopen("running.in", "r", stdin); freopen("running.out", "w", stdout); scanf("%lld%lld", &n, &m); for (int i = 1; i <= m; i++) scanf("%d", &arr[i]); sieve(); for (int i = 1; i * i <= n; i++) if (n % i == 0) { solve(i); if (n / i != i) solve(n / i); } printf("%lld", n - ans); return 0; }
// running.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 60, MAX_DOM = 1e7 + 200;
// Capacity of prime[]. The linear sieve stores every prime below MAX_DOM, 1-indexed.
// pi(10^7) = 664579, and [10^7, MAX_DOM) adds at most 200 more, so 700000 suffices.
// Raise this if MAX_DOM grows. Sizing prime[] by MAX_DOM would cost ~80 MB.
const int MAX_PRIME = 700000;

// n, m and arr[1..m] are read by main(). After sieve(), phi[x] is the prefix sum
// phi(1) + ... + phi(x) for x < MAX_DOM. ump memoises larger prefix sums for varphi().
long long n, m, arr[MAX_N], phi[MAX_DOM], prime[MAX_PRIME], tot, ans;
std::unordered_map<long long, long long> ump;
// vis[x] is true once x has been marked composite by the sieve.
bool vis[MAX_DOM];

// Greatest common divisor by the iterative Euclidean algorithm; gcd(a, 0) == a.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

// Returns the prefix sum phi(1) + ... + phi(num) of Euler's totient (not phi(num) itself).
// Small arguments come from the sieved table, so sieve() must run first. Larger ones
// use the identity sum_{d=1}^{num} S(num / d) = num * (num + 1) / 2 over the distinct
// quotients num / d. Their results are memoised in ump.
ll varphi(ll num)
{
    if (num < MAX_DOM)
        return phi[num];
    auto cached = ump.find(num);
    if (cached != ump.end())
        return cached->second;
    // num * (num + 1) / 2 without overflowing the intermediate product. Halve whichever
    // factor is even. The value itself must still fit in a long long, which holds
    // for num up to about 4.29e9.
    ll ans = (num % 2 == 0) ? (num / 2) * (num + 1) : num * ((num + 1) / 2);
    // Every d in [d, gx] shares the quotient num / d, so subtract them as one group.
    for (ll d = 2, gx; d <= num; d = gx + 1)
    {
        gx = num / (num / d);
        ans -= (gx - d + 1) * varphi(num / d);
    }
    ump[num] = ans;
    return ans;
}

void sieve()
{
    phi[1] = 1;
    for (ll i = 2; i < MAX_DOM; i++)
    {
        if (!vis[i])
            prime[++tot] = i, phi[i] = i - 1;
        for (ll j = 1; j <= tot && i * prime[j] < MAX_DOM; j++)
        {
            vis[i * prime[j]] = true;
            if (i % prime[j] == 0)
            {
                phi[i * prime[j]] = phi[i] * prime[j];
                break;
            }
            phi[i * prime[j]] = phi[i] * phi[prime[j]];
        }
    }
    for (int i = 2; i < MAX_DOM; i++)
        phi[i] += phi[i - 1];
}

void solve(int factor)
{
    for (int i = 1; i <= m; i++)
        if (factor % gcd(arr[i], n) == 0)
        {
            ans += varphi(n / factor) - varphi(n / factor - 1);
            break;
        }
}

int main()
{
    freopen("running.in", "r", stdin);
    freopen("running.out", "w", stdout);
    scanf("%lld%lld", &n, &m);
    for (int i = 1; i <= m; i++)
        scanf("%d", &arr[i]);
    sieve();
    for (int i = 1; i * i <= n; i++)
        if (n % i == 0)
        {
            solve(i);
            if (n / i != i)
                solve(n / i);
        }
    printf("%lld", n - ans);
    return 0;
}

C – Tree

这道题真 tm 的烦。

正常来讲,可以想出一个\(O(n^3)\)的背包转移,但是会超时。考虑优化:合并一棵子树时,第一维只枚举已经合并进来的节点数,第二维只枚举这棵子树的大小(两者都不超过上限)。这样每一对节点只会在它们的最近公共祖先处被合并一次,而点对的个数是\(O(n^2)\)级别的,所以总复杂度是\(O(n^2)\)。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// tree.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 3005;

// n: node count, limit: upper bound handed to dfs() as its `dep` budget.
// head[u]: first edge index out of u (-1 ends a list), current: next free edge slot,
// siz[u]: subtree size, dp[u][j]: knapsack table, val[u]: node value.
int n, limit;
int head[MAX_N], current;
int siz[MAX_N];
int dp[MAX_N][MAX_N];
int val[MAX_N];

// One directed half of an undirected tree edge, chained per source node.
struct edge
{
    int to;  // endpoint this half-edge points at
    int nxt; // next edge index from the same source, -1 terminates
};
edge edges[2 * MAX_N];
// Prepends the half-edge src -> dst to src's adjacency list (head[src] / edges[].nxt chain).
void addpath(int src, int dst)
{
    auto &slot = edges[current];
    slot.nxt = head[src];
    slot.to = dst;
    head[src] = current++;
}
// Tree knapsack for the subtree of u (parent fa). `dep` is how many nodes u's part may
// still use, and each child receives dep - 1.
// Each dp[u][j] (1 <= j <= min(siz[u], dep)) is the value of some connected set that
// contains u, lies in u's subtree and has at most j nodes. The maximum of
// dp[u][1..J] is the best such value with at most J nodes; callers only take maxima.
void dfs(int u, int fa, int dep)
{
    siz[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
            dfs(edges[i].to, u, dep - 1), siz[u] += siz[edges[i].to];
    int bound = min(siz[u], dep);
    for (int i = 1; i <= bound; i++)
        dp[u][i] = val[u];
    // Number of nodes of u's subtree merged into dp[u] so far (just u at the start).
    int merged = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        int v = edges[i].to;
        if (v == fa)
            continue;
        // dfs(v) only wrote dp[v][1..child_bound].
        int child_bound = min(siz[v], dep - 1);
        // Enumerate only sizes that can actually occur. k is capped by what has been
        // merged so far, and a by the child's table. Each pair of nodes is then combined
        // at most once (at their LCA), so the DP is O(n^2) rather than O(n^3).
        // k runs downwards so dp[u][k] still holds the state before this child.
        for (int k = min(merged, bound); k >= 1; k--)
            for (int a = 1; a <= child_bound && a + k <= bound; a++)
                dp[u][a + k] = max(dp[u][a + k], dp[u][k] + dp[v][a]);
        merged += siz[v];
    }
}
int main()
{
memset(head, -1, sizeof(head));
scanf("%d%d", &n, &limit);
for (int i = 1; i <= n; i++)
scanf("%d", &val[i]);
for (int i = 1, u, v; i <= n - 1; i++)
scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
dfs(1, 0, limit);
int ans = 0;
for (int i = 1; i <= limit; i++)
ans = max(ans, dp[1][i]);
printf("%d", ans);
return 0;
}
// tree.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 3005; int n, limit, head[MAX_N], current, siz[MAX_N], dp[MAX_N][MAX_N], val[MAX_N]; struct edge { int to, nxt; } edges[MAX_N << 1]; void addpath(int src, int dst) { edges[current].to = dst, edges[current].nxt = head[src]; head[src] = current++; } void dfs(int u, int fa, int dep) { siz[u] = 1; for (int i = head[u]; i != -1; i = edges[i].nxt) if (edges[i].to != fa) dfs(edges[i].to, u, dep - 1), siz[u] += siz[edges[i].to]; int bound = min(siz[u], dep); for (int i = 1; i <= bound; i++) dp[u][i] = val[u]; for (int i = head[u]; i != -1; i = edges[i].nxt) if (edges[i].to != fa) for (int k = bound; k >= 1; k--) for (int a = 0; a + k <= bound; a++) dp[u][a + k] = max(dp[u][a + k], dp[u][k] + dp[edges[i].to][a]); } int main() { memset(head, -1, sizeof(head)); scanf("%d%d", &n, &limit); for (int i = 1; i <= n; i++) scanf("%d", &val[i]); for (int i = 1, u, v; i <= n - 1; i++) scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u); dfs(1, 0, limit); int ans = 0; for (int i = 1; i <= limit; i++) ans = max(ans, dp[1][i]); printf("%d", ans); return 0; }
// tree.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 3005;

// n: node count, limit: upper bound handed to dfs() as its `dep` budget.
// head[u]: first edge index out of u (-1 ends a list), current: next free edge slot,
// siz[u]: subtree size, dp[u][j]: knapsack table, val[u]: node value.
int n, limit;
int head[MAX_N], current;
int siz[MAX_N];
int dp[MAX_N][MAX_N];
int val[MAX_N];

// One directed half of an undirected tree edge, chained per source node.
struct edge
{
    int to;  // endpoint this half-edge points at
    int nxt; // next edge index from the same source, -1 terminates
};
edge edges[2 * MAX_N];

// Prepends the half-edge src -> dst to src's adjacency list (head[src] / edges[].nxt chain).
void addpath(int src, int dst)
{
    auto &slot = edges[current];
    slot.nxt = head[src];
    slot.to = dst;
    head[src] = current++;
}

// Tree knapsack for the subtree of u (parent fa). `dep` is how many nodes u's part may
// still use, and each child receives dep - 1.
// Each dp[u][j] (1 <= j <= min(siz[u], dep)) is the value of some connected set that
// contains u, lies in u's subtree and has at most j nodes. The maximum of
// dp[u][1..J] is the best such value with at most J nodes; callers only take maxima.
void dfs(int u, int fa, int dep)
{
    siz[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
            dfs(edges[i].to, u, dep - 1), siz[u] += siz[edges[i].to];
    int bound = min(siz[u], dep);
    for (int i = 1; i <= bound; i++)
        dp[u][i] = val[u];
    // Number of nodes of u's subtree merged into dp[u] so far (just u at the start).
    int merged = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        int v = edges[i].to;
        if (v == fa)
            continue;
        // dfs(v) only wrote dp[v][1..child_bound].
        int child_bound = min(siz[v], dep - 1);
        // Enumerate only sizes that can actually occur. k is capped by what has been
        // merged so far, and a by the child's table. Each pair of nodes is then combined
        // at most once (at their LCA), so the DP is O(n^2) rather than O(n^3).
        // k runs downwards so dp[u][k] still holds the state before this child.
        for (int k = min(merged, bound); k >= 1; k--)
            for (int a = 1; a <= child_bound && a + k <= bound; a++)
                dp[u][a + k] = max(dp[u][a + k], dp[u][k] + dp[v][a]);
        merged += siz[v];
    }
}
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d", &n, &limit);
    for (int i = 1; i <= n; i++)
        scanf("%d", &val[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    dfs(1, 0, limit);
    int ans = 0;
    for (int i = 1; i <= limit; i++)
        ans = max(ans, dp[1][i]);
    printf("%d", ans);
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *