主要思路
虽然是个套路题,但是为了完全体会 min-max 反演的魅力之处,我决定写篇博客来强化一下。
我们大概推一推能发现:
\[ f_n = 2f_{n – 1} + f_{n – 2} \]
那么,我们现在就是要求:
\[ \sum_{i = 1}^n i \text{lcm}_{j = 1}^i f_j \]
如何把多个数做 LCM 的过程通过 min-max 反演进行优化呢?我们首先把每个数字看作是一个集合,那么这个集合的权就是集合中的因子的积。那么,两个数的最大公因数可以被理解为这两个集合的交。既然我们要求出最大的集合(最小公倍数)我们可以列出关系:
\[ \max(S)=\sum_{\emptyset\neq T\subseteq S}(-1)^{|T|-1}\min(T) \\ \text{lcm}_{i = 1}^n f_i = \prod_{T \subset S, T \neq \emptyset} \gcd(T)^{(-1)^{|T| – 1}} \]
然后我们知道一件事,就是 \(\gcd(f_a, f_b) = f_{\gcd(a, b)}\),然后带入进去:
\[ \text{lcm}_{i = 1}^n f_i = \prod_{T \subset S, T \neq \emptyset} f_{\gcd(T)}^{(-1)^{|T| – 1}} \]
喜闻乐见,可以考虑把指数的那个贡献计算出来:
\[ \text{lcm}_{i = 1}^n f_i = \prod_{d = 1}^n f_d^{\sum_{d | x} [x \leq n] \mu(\frac{x}{d})} \]
反过来即可:
\[ \text{lcm}_{i = 1}^n f_i = \prod_{d = 1}^n \prod_{i | d} f_i^{\mu(\frac{d}{i})} \]
代码
// BZ4833.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1e6 + 200;
int T, n, mod, primes[MAX_N], tot, mu[MAX_N], f[MAX_N], finv[MAX_N], units[MAX_N];
bool vis[MAX_N];
int fpow(int bas, int tim)
{
int ret = 1;
while (tim)
{
if (tim & 1)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ret;
}
void sieve()
{
mu[1] = 1;
for (int i = 2; i < MAX_N; i++)
{
if (!vis[i])
primes[++tot] = i, mu[i] = -1;
for (int j = 1; j <= tot && 1LL * i * primes[j] < MAX_N; j++)
{
vis[i * primes[j]] = true;
if (i % primes[j] == 0)
{
mu[i * primes[j]] = 0;
break;
}
mu[i * primes[j]] = -mu[i];
}
}
}
int main()
{
scanf("%d", &T), sieve();
while (T--)
{
scanf("%d%d", &n, &mod);
f[1] = 1, finv[1] = 1;
for (int i = 2; i <= n; i++)
f[i] = (2LL * f[i - 1] + f[i - 2]) % mod, finv[i] = fpow(f[i], mod - 2);
for (int i = 1; i <= n; i++)
units[i] = 1;
for (int i = 1; i <= n; i++)
for (int j = i; j <= n; j += i)
if (mu[j / i] < 0)
units[j] = 1LL * units[j] * finv[i] % mod;
else if (mu[j / i] > 0)
units[j] = 1LL * units[j] * f[i] % mod;
for (int i = 2; i <= n; i++)
units[i] = 1LL * units[i - 1] * units[i] % mod;
int ans = 0;
for (int i = 1; i <= n; i++)
ans = (0LL + ans + 1LL * units[i] * i % mod) % mod;
printf("%d\n", ans);
}
return 0;
}