主要思路与推导
这道题是一道好题。
我们可以直接暴力 DP,考虑状态\(dp[i][j]\)为\(i\)个\(1\)与\(j\)个\(-1\)的答案。那么,我们可以从少一个\(1\)的情况和少一个\(-1\)的情况进行转移:
- 对于少一个\(1\)的情况,也就是\(dp[i – 1][j]\),我们把\(1\)放在这样序列的前面,这样可以让所有序列的最大前缀和都加一,所以贡献就是\(dp[i – 1][j] + {i + j – 1 \choose j}\)。
- 对于少一个\(-1\)的情况,也就是\(dp[i][j – 1]\),我们把\(-1\)放在这样的区间前面,会产生两种情况:对于贡献大于\(0\)的序列,\(-1\)的贡献就是这样的序列的个数;如果贡献本来就小于等于\(0\),那么就无贡献。所以我们还需要额外计算一个无贡献的序列的个数并加回来。
现在我们只需要解决无贡献的序列的个数就行。发现如果要统计这样的数,其实也需要知道\(1, -1\)的个数,所以我们需要再搞一个二维的 DP,状态跟之前差不多:
\[ zero[i][j] = zero[i – 1][j] + zero[i][j – 1] \]
特别的,当\(i>j\)时,取值为\(0\);当\(i = 0\),取值为\(1\)。所以这道题就搞定了,最后放上推好的式子:
\[ dp[i][j] = dp[i – 1][j] + {i + j – 1 \choose j} + dp[i][j – 1] – ({i + j – 1 \choose i} – zero[i][j – 1]) \]
代码
// CF1204E.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 4020, mod = 998244853;
int fac[MAX_N], fac_inv[MAX_N], n, m, k[MAX_N][MAX_N], dp[MAX_N][MAX_N];
int quick_pow(int bas, int tim)
{
int ans = 1;
while (tim)
{
if (tim & 1)
ans = 1LL * ans * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ans;
}
int combinator(int n_, int k_) { return 1LL * fac[n_] * fac_inv[n_ - k_] % mod * fac_inv[k_] % mod; }
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++)
k[0][i] = 1;
for (int i = 1; i <= n; i++)
for (int j = i; j <= m; j++)
k[i][j] = (1LL * k[i - 1][j] + 1LL * k[i][j - 1]) % mod;
for (int i = fac[0] = 1; i < MAX_N; i++)
fac[i] = 1LL * fac[i - 1] * i % mod;
fac_inv[MAX_N - 1] = quick_pow(fac[MAX_N - 1], mod - 2);
for (int i = MAX_N - 2; i >= 1; i--)
fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
fac_inv[0] = 1;
for (int i = 1; i <= n; i++)
dp[i][0] = i;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
{
ll tmp = 1LL * combinator(i + j - 1, j) + 1LL * dp[i - 1][j] + 1LL * dp[i][j - 1] - (1LL * combinator(i + j - 1, i) - k[i][j - 1]);
while (tmp < 0)
tmp += mod;
tmp %= mod;
dp[i][j] = tmp;
}
printf("%d", dp[n][m]);
return 0;
}
// CF1204E.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 4020, mod = 998244853;
int fac[MAX_N], fac_inv[MAX_N], n, m, k[MAX_N][MAX_N], dp[MAX_N][MAX_N];
int quick_pow(int bas, int tim)
{
int ans = 1;
while (tim)
{
if (tim & 1)
ans = 1LL * ans * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ans;
}
int combinator(int n_, int k_) { return 1LL * fac[n_] * fac_inv[n_ - k_] % mod * fac_inv[k_] % mod; }
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++)
k[0][i] = 1;
for (int i = 1; i <= n; i++)
for (int j = i; j <= m; j++)
k[i][j] = (1LL * k[i - 1][j] + 1LL * k[i][j - 1]) % mod;
for (int i = fac[0] = 1; i < MAX_N; i++)
fac[i] = 1LL * fac[i - 1] * i % mod;
fac_inv[MAX_N - 1] = quick_pow(fac[MAX_N - 1], mod - 2);
for (int i = MAX_N - 2; i >= 1; i--)
fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
fac_inv[0] = 1;
for (int i = 1; i <= n; i++)
dp[i][0] = i;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
{
ll tmp = 1LL * combinator(i + j - 1, j) + 1LL * dp[i - 1][j] + 1LL * dp[i][j - 1] - (1LL * combinator(i + j - 1, i) - k[i][j - 1]);
while (tmp < 0)
tmp += mod;
tmp %= mod;
dp[i][j] = tmp;
}
printf("%d", dp[n][m]);
return 0;
}
// CF1204E.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int MAX_N = 4020, mod = 998244853; int fac[MAX_N], fac_inv[MAX_N], n, m, k[MAX_N][MAX_N], dp[MAX_N][MAX_N]; int quick_pow(int bas, int tim) { int ans = 1; while (tim) { if (tim & 1) ans = 1LL * ans * bas % mod; bas = 1LL * bas * bas % mod; tim >>= 1; } return ans; } int combinator(int n_, int k_) { return 1LL * fac[n_] * fac_inv[n_ - k_] % mod * fac_inv[k_] % mod; } int main() { scanf("%d%d", &n, &m); for (int i = 1; i <= m; i++) k[0][i] = 1; for (int i = 1; i <= n; i++) for (int j = i; j <= m; j++) k[i][j] = (1LL * k[i - 1][j] + 1LL * k[i][j - 1]) % mod; for (int i = fac[0] = 1; i < MAX_N; i++) fac[i] = 1LL * fac[i - 1] * i % mod; fac_inv[MAX_N - 1] = quick_pow(fac[MAX_N - 1], mod - 2); for (int i = MAX_N - 2; i >= 1; i--) fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod; fac_inv[0] = 1; for (int i = 1; i <= n; i++) dp[i][0] = i; for (int i = 1; i <= n; i++) for (int j = 1; j <= m; j++) { ll tmp = 1LL * combinator(i + j - 1, j) + 1LL * dp[i - 1][j] + 1LL * dp[i][j - 1] - (1LL * combinator(i + j - 1, i) - k[i][j - 1]); while (tmp < 0) tmp += mod; tmp %= mod; dp[i][j] = tmp; } printf("%d", dp[n][m]); return 0; }