主要思路
首先规定,\(N<M\)。之后写出统计答案的式子:
\[ \sum_{i=1}^N \sum_{j=1}^M lcm(i,j) \\ \sum_{i=1}^N \sum_{j=1}^M \frac{ij}{\gcd(i,j)} \]
然后我们可以试着去枚举最大公约数\(gcd(i,j)=d\),然后让\(i,j\)成为唯一的两个互质因子。
\[ \sum_{d=1}^N \sum_{i=1}^{\frac{N}{d}} \sum_{j=1}^{\frac{M}{d}} \frac{id*jd}{d}[gcd(i,j)=1] \\ \sum_{d=1}^N d \sum_{i=1}^{\frac{N}{d}} \sum_{j=1}^{\frac{M}{d}} ij \sum_{x|gcd(i,j)} \mu(x) \]
然后再次对\(gcd(i,j) = x\)进行枚举:
\[ \sum_{d=1}^N d \sum_{x=1}^{\frac{N}{d}} \mu(x) \sum_{i=1}^{\frac{N}{x}} \sum_{j=1}^{\frac{M}{x}} i x^2 j \\ \sum_{d=1}^N d \sum_{x=1}^{\frac{N}{d}} \mu(x) x^2 (\sum_{i=1}^{\frac{N}{x}} i) (\sum_{j=1}^{\frac{M}{x}} j) \]
我们可以使用整除分块来降低复杂度:处理前缀和\(\mu(x)x^2\),把\((\sum_{i=1}^{\frac{N}{x}} i) (\sum_{j=1}^{\frac{M}{x}} j)\)写成数列求和的形式:
\[ \sum_{d=1}^N d \sum_{x=1}^{\frac{N}{d}} \mu(x) x^2 \frac{(\frac{N}{x}+1)\frac{N}{x}}{2} \frac{\frac{M}{x}(\frac{M}{x}+1)}{2} \]
为了方便理解,设两个函数\(calc(a,b),g(a,b)\):
\[ g(a,b) = \frac{a(a+1)}{2} \frac{b(b+1)}{2} \\ calc(a,b) = \sum_{x=1}^{a} \mu(x) x^2 * g(\lfloor \frac{a}{x} \rfloor, \lfloor \frac{b}{x} \rfloor) \\ ans = \sum_{d=1}^N d*calc(\frac{N}{d},\frac{N}{d}) \]
这样就可以简单明了的知道那些地方需要整除分块了。
代码
// P1829.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int MAX_N = 1e7 + 100, mod = 20101009; ll n, m, mu[MAX_N], musum[MAX_N], prime[MAX_N], tot; bool vis[MAX_N]; ll g(ll a, ll b) { return (((a * (a + 1)) / 2) % mod) * (((b * (b + 1)) / 2) % mod) % mod; } ll sum(ll a, ll b) { ll ans = 0; for (ll x = 1, y = 0; x <= a; x = y + 1) { y = min(a / (a / x), b / (b / x)); ans = (ans + (musum[y] - musum[x - 1] + mod) % mod * g(a / x, b / x) % mod) % mod; } return ans; } int main() { scanf("%lld%lld", &n, &m); if (n > m) swap(n, m); mu[1] = 1; for (ll i = 2; i < MAX_N; i++) { if (!vis[i]) prime[++tot] = i, mu[i] = -1; for (ll j = 1; j <= tot && i * prime[j] < MAX_N; j++) { vis[i * prime[j]] = true; if (i % prime[j] == 0) break; else mu[i * prime[j]] = -mu[i]; } } for (ll i = 1; i < MAX_N; i++) musum[i] = (musum[i - 1] + (mu[i] + mod) % mod * (i * i % mod) % mod) % mod; ll ans = 0; for (ll x = 1, y = 0; x <= n; x = y + 1) { y = min(n / (n / x), m / (m / x)); ans = (ans + (y - x + 1) * (y + x) / 2 % mod * sum(n / x, m / x) % mod) % mod; } printf("%lld", ans); return 0; }