主要思路
首先,不考虑原问题,只考虑最优策略,显然是当前能匹配的串越长越好。所以可以考虑二分答案,算当前拼接次数下能拼出来的最长的串,再将其与 \(n\) 进行比较。
那我们怎么算这个东西呢?我们可以考虑算一个 \(DP\),设 \(f[a][b][L]\) 表示当前第 \(L\) 次拼上去一个开头字符为 \(a\) 的字符串、第 \(L + 1\) 次拼上去一个开头为 \(b\) 的字符串的最短大小,且这两个字符串拼起来在原串中并未出现。这个东西如果不考虑时间复杂度,可以用 Floyd 去算。在这里,发现这个 \(L\) 很大,但是 \(a, b\) 很小,所以可以尝试矩阵快速幂。
我们现在需要算的只是 \(f[][][1]\) 的数据。这个数据可以考虑建一颗 Trie 树,把所有的后缀塞进去然后暴力计数。但是这样肯定不行。
然而,有一个性质可以被利用。考虑计算 \(g[a][b][L]\),表示 \(T\) 中有多少种 \(a\) 开头、\(b\) 结尾的、长度为 \(L\) 的字符串(需要本质不同)。如果这个数量达到了 \(4^{L – 2}\),那么说明这个尺度里所有的字符串都出现过了,所以不存在不出现的情况;亦而反之。
这就可以引出一个性质,考虑不合法的情况:
\[ |T| \geq g[a][b][L] \geq 4^{L-2} \\ |T| \geq 4^{L-2} \]
这个 \(L\) 肯定很小,所以 Trie 之类的都只要 \(\Theta(L)\) 级别的即可。
代码
// CF461E.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 1e5 + 200; typedef long long ll; int m, ch[MAX_N * 20][4], calc[4][4][20], ptot, mlen; ll n; bool tag[MAX_N * 20]; char str[MAX_N]; ll fpow(ll bas, ll tim) { ll ret = 1; while (tim) { if (tim & 1) ret *= bas; bas *= bas; tim >>= 1; } return ret; } struct matrix { ll mat[4][4]; ll *operator[](const int &rhs) { return mat[rhs]; } matrix operator*(const matrix &rhs) { matrix ret; memset(ret.mat, 0x3f, sizeof(ret.mat)); for (int k = 0; k < 4; k++) for (int i = 0; i < 4; i++) for (int j = 0; j < 4; j++) ret[i][j] = min(ret[i][j], mat[i][k] + rhs.mat[k][j]); return ret; } matrix operator^(const ll &rhs) { ll tim = rhs - 1; matrix ret = *this, bas = *this; while (tim) { if (tim & 1) ret = ret * bas; bas = bas * bas; tim >>= 1; } return ret; } } trans; void insert(int start_pos) { int p = 0, sc = str[start_pos] - 'A'; for (int i = start_pos; i <= min(start_pos + mlen - 1, m); i++) { int c = str[i] - 'A'; if (ch[p][c] == 0) ch[p][c] = ++ptot; p = ch[p][c]; if (!tag[p]) tag[p] = true, calc[sc][c][i - start_pos + 1]++; } } bool check(ll mid) { matrix ret = trans ^ mid; bool flag = true; for (int i = 0; i < 4; i++) for (int j = 0; j < 4; j++) flag &= ret[i][j] >= n; return flag; } int main() { scanf("%lld%s", &n, str + 1), m = strlen(str + 1), mlen = min(11, m + 1); for (int i = 1; i <= m; i++) insert(i); for (int i = 0; i < 4; i++) for (int j = 0; j < 4; j++) for (int k = mlen; k >= 2; k--) if (calc[i][j][k] != fpow(4, k - 2)) trans[i][j] = k - 1; ll l = 1, r = n, res = 0; while (l <= r) { ll mid = (l + r) >> 1; if (check(mid)) r = mid - 1, res = mid; else l = mid + 1; } printf("%lld\n", res); return 0; }