Loading [MathJax]/extensions/tex2jax.js

P2257:YY 的 GCD 题解

主要思路

我们首先可以搞到一个统计答案的式子(不妨设 \(n \le m\),否则交换 \(n, m\)):

\[ \begin{aligned} &\sum_{k \in \text{prime}} \sum_{i=1}^{n} \sum_{j=1}^{m} [\gcd(i, j) = k] \\ =&\sum_{k \in \text{prime}} \sum_{i=1}^{\lfloor \frac{n}{k} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{k} \rfloor} [\gcd(i,j)=1] \\ =&\sum_{k \in \text{prime}} \sum_{i=1}^{\lfloor \frac{n}{k} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{k} \rfloor} \sum_{d \mid \gcd(i,j)} \mu(d) \\ =&\sum_{k=1,\ k \in \text{prime}}^{n} \sum_{d=1}^{\lfloor \frac{n}{k} \rfloor} \mu(d) \left\lfloor \frac{n}{kd} \right\rfloor \left\lfloor \frac{m}{kd} \right\rfloor \\ &\text{设 } T=kd \text{:} \\ =&\sum_{k=1,\ k \in \text{prime}}^{n} \sum_{d=1}^{\lfloor \frac{n}{k} \rfloor} \mu(d) \left\lfloor \frac{n}{T} \right\rfloor \left\lfloor \frac{m}{T} \right\rfloor \\ =&\sum_{T=1}^{n} \left\lfloor \frac{n}{T} \right\rfloor \left\lfloor \frac{m}{T} \right\rfloor \sum_{k \mid T,\ k \in \text{prime}} \mu\left(\frac{T}{k}\right) \end{aligned} \]

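在这个最终的式子里,记 \(f(T)=\sum_{k \mid T,\ k \in \text{prime}} \mu(\frac{T}{k})\)。它与询问无关,可以预处理:先用线性筛求出 \(\mu\),再对每个质数 \(p\) 枚举它的倍数 \(T\),把 \(\mu(\frac{T}{p})\) 累加到 \(f(T)\) 上,最后求出 \(f\) 的前缀和。由于 \(\lfloor \frac{n}{T} \rfloor \lfloor \frac{m}{T} \rfloor\) 只有 \(O(\sqrt{n}+\sqrt{m})\) 段取值不变的区间,对它进行整数分块,再配合 \(f\) 的前缀和,每次询问的时间就降到了根号级别。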

具体实现见下面的代码。

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// P2257.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Exclusive upper bound of every table below; solve() indexes sum[] by
// values up to min(n, m), so each queried n and m must stay below MAX_N.
constexpr int MAX_N = 10001000;

int T;              // number of queries still to be answered (read in main)
int cnt;            // how many primes the sieve has stored in prime[]
int miu[MAX_N];     // miu[x]: Moebius function mu(x)
int prime[MAX_N];   // prime[1..cnt]: primes in increasing order
int prefix[MAX_N];  // prefix[t]: f(t) = sum of mu(t / p) over primes p | t
int sum[MAX_N];     // sum[t]: f(1) + ... + f(t)
bool flag[MAX_N];   // flag[x] (x >= 2): true iff x is composite
// Builds every lookup table used by solve(), for all indices below MAX_N:
//   miu[x]    - Moebius function mu(x), produced by a linear (Euler) sieve
//               that also collects the primes into prime[1..cnt];
//   prefix[t] - f(t) = sum of mu(t / p) over the primes p dividing t
//               (a per-index value, not a running total);
//   sum[t]    - f(1) + f(2) + ... + f(t), the running total of prefix[].
// Not idempotent: a second call would append to prime[] again and
// accumulate into prefix[] a second time.
void sieve()
{
    miu[1] = 1;

    // Linear sieve: each composite x * p is reached once, with p being its
    // smallest prime factor.
    for (int x = 2; x < MAX_N; ++x)
    {
        if (!flag[x])
        {
            prime[++cnt] = x;
            miu[x] = -1;
        }
        for (int k = 1; k <= cnt; ++k)
        {
            const int p = prime[k];
            const int multiple = x * p;
            if (multiple >= MAX_N)
                break;
            flag[multiple] = true;
            if (x % p == 0)
                break; // p * p divides multiple, so its mu stays 0
            miu[multiple] = -miu[x];
        }
    }

    // Prime p adds mu(q) to f(p * q) for every q with p * q < MAX_N.
    for (int k = 1; k <= cnt; ++k)
    {
        const int p = prime[k];
        for (int q = 1, t = p; t < MAX_N; ++q, t += p)
            prefix[t] += miu[q];
    }

    // sum[0] is 0 (static storage), so sum[t] = prefix[1] + ... + prefix[t].
    std::partial_sum(prefix + 1, prefix + MAX_N, sum + 1);
}
// Answers one query: the number of pairs (i, j) with 1 <= i <= a, 1 <= j <= b
// and gcd(i, j) prime, evaluated as
//   sum over T of floor(a/T) * floor(b/T) * f(T),
// where sum[] holds prefix sums of f. Runs of T that share the same pair of
// quotients are handled in one step (divisor blocks), so the loop performs
// O(sqrt(a) + sqrt(b)) iterations.
ll solve(int a, int b)
{
    const int lo = std::min(a, b);
    const int hi = std::max(a, b);
    ll total = 0;
    int left = 1;
    while (left <= lo)
    {
        const int qa = lo / left;
        const int qb = hi / left;
        // Largest T for which both quotients stay unchanged.
        const int right = std::min(lo / qa, hi / qb);
        // Widen to ll first: the product can exceed the int range.
        total += static_cast<ll>(sum[right] - sum[left - 1]) * qa * qb;
        left = right + 1;
    }
    return total;
}
// Reads the query count T, then T lines "n m", and prints for each query the
// number of pairs (x, y), 1 <= x <= n, 1 <= y <= m, whose gcd is prime.
// Stops early on truncated or malformed input instead of computing with
// uninitialized values.
int main()
{
    sieve();
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--)
    {
        int n = 0, m = 0;
        if (scanf("%d%d", &n, &m) != 2)
            break;
        // solve() puts its arguments in order itself, so no swap is needed here.
        printf("%lld\n", solve(n, m));
    }
    return 0;
}
// P2257.cpp
// Counts pairs (x, y), 1 <= x <= n, 1 <= y <= m, with gcd(x, y) prime, via
//   answer = sum_{T=1}^{min(n,m)} floor(n/T) * floor(m/T) * f(T),
//   f(T)   = sum over primes p | T of mu(T / p).
#include <bits/stdc++.h>
using ll = long long;
using namespace std;

// Exclusive upper bound of all tables; every queried n and m must be below it.
const int MAX_N = 10001000;
int T, miu[MAX_N], prime[MAX_N], cnt, prefix[MAX_N], sum[MAX_N];
bool flag[MAX_N];

// Fills miu (Moebius function, linear sieve), prime[1..cnt],
// prefix[T] = f(T) and sum[T] = f(1) + ... + f(T) for every T < MAX_N.
void sieve()
{
    miu[1] = 1;
    for (int i = 2; i < MAX_N; i++)
    {
        if (!flag[i])
            prime[++cnt] = i, miu[i] = -1;
        for (int j = 1; j <= cnt && i * prime[j] < MAX_N; j++)
        {
            flag[i * prime[j]] = true;
            if (i % prime[j] == 0)
                break; // prime[j]^2 divides i * prime[j], so its mu stays 0
            miu[i * prime[j]] = -miu[i];
        }
    }
    // Prime prime[i] adds mu(j) to f(j * prime[i]).
    for (int i = 1; i <= cnt; i++)
        for (int j = 1; prime[i] * j < MAX_N; j++)
            prefix[j * prime[i]] += miu[j];
    for (int i = 1; i < MAX_N; i++)
        sum[i] = sum[i - 1] + prefix[i];
}

// Answers one query with divisor blocks over T: O(sqrt(a) + sqrt(b)) steps.
ll solve(int a, int b)
{
    ll ans = 0;
    if (a > b)
        swap(a, b);
    for (int l = 1, r = 0; l <= a; l = r + 1)
    {
        r = min(a / (a / l), b / (b / l));
        // Widen to ll before multiplying: the product can exceed int.
        ans += static_cast<ll>(sum[r] - sum[l - 1]) * (a / l) * (b / l);
    }
    return ans;
}

// Reads T queries "n m" and prints one answer per line; stops on
// truncated or malformed input instead of using uninitialized values.
int main()
{
    sieve();
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--)
    {
        int n = 0, m = 0;
        if (scanf("%d%d", &n, &m) != 2)
            break;
        printf("%lld\n", solve(n, m)); // solve() orders its arguments itself
    }
    return 0;
}
// P2257.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Exclusive upper bound of every table below; solve() indexes sum[] by
// values up to min(n, m), so each queried n and m must stay below MAX_N.
constexpr int MAX_N = 10001000;

int T;              // number of queries still to be answered (read in main)
int cnt;            // how many primes the sieve has stored in prime[]
int miu[MAX_N];     // miu[x]: Moebius function mu(x)
int prime[MAX_N];   // prime[1..cnt]: primes in increasing order
int prefix[MAX_N];  // prefix[t]: f(t) = sum of mu(t / p) over primes p | t
int sum[MAX_N];     // sum[t]: f(1) + ... + f(t)
bool flag[MAX_N];   // flag[x] (x >= 2): true iff x is composite
// Builds every lookup table used by solve(), for all indices below MAX_N:
//   miu[x]    - Moebius function mu(x), produced by a linear (Euler) sieve
//               that also collects the primes into prime[1..cnt];
//   prefix[t] - f(t) = sum of mu(t / p) over the primes p dividing t
//               (a per-index value, not a running total);
//   sum[t]    - f(1) + f(2) + ... + f(t), the running total of prefix[].
// Not idempotent: a second call would append to prime[] again and
// accumulate into prefix[] a second time.
void sieve()
{
    miu[1] = 1;

    // Linear sieve: each composite x * p is reached once, with p being its
    // smallest prime factor.
    for (int x = 2; x < MAX_N; ++x)
    {
        if (!flag[x])
        {
            prime[++cnt] = x;
            miu[x] = -1;
        }
        for (int k = 1; k <= cnt; ++k)
        {
            const int p = prime[k];
            const int multiple = x * p;
            if (multiple >= MAX_N)
                break;
            flag[multiple] = true;
            if (x % p == 0)
                break; // p * p divides multiple, so its mu stays 0
            miu[multiple] = -miu[x];
        }
    }

    // Prime p adds mu(q) to f(p * q) for every q with p * q < MAX_N.
    for (int k = 1; k <= cnt; ++k)
    {
        const int p = prime[k];
        for (int q = 1, t = p; t < MAX_N; ++q, t += p)
            prefix[t] += miu[q];
    }

    // sum[0] is 0 (static storage), so sum[t] = prefix[1] + ... + prefix[t].
    std::partial_sum(prefix + 1, prefix + MAX_N, sum + 1);
}
// Answers one query: the number of pairs (i, j) with 1 <= i <= a, 1 <= j <= b
// and gcd(i, j) prime, evaluated as
//   sum over T of floor(a/T) * floor(b/T) * f(T),
// where sum[] holds prefix sums of f. Runs of T that share the same pair of
// quotients are handled in one step (divisor blocks), so the loop performs
// O(sqrt(a) + sqrt(b)) iterations.
ll solve(int a, int b)
{
    const int lo = std::min(a, b);
    const int hi = std::max(a, b);
    ll total = 0;
    int left = 1;
    while (left <= lo)
    {
        const int qa = lo / left;
        const int qb = hi / left;
        // Largest T for which both quotients stay unchanged.
        const int right = std::min(lo / qa, hi / qb);
        // Widen to ll first: the product can exceed the int range.
        total += static_cast<ll>(sum[right] - sum[left - 1]) * qa * qb;
        left = right + 1;
    }
    return total;
}
// Reads the query count T, then T lines "n m", and prints for each query the
// number of pairs (x, y), 1 <= x <= n, 1 <= y <= m, whose gcd is prime.
// Stops early on truncated or malformed input instead of computing with
// uninitialized values.
int main()
{
    sieve();
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--)
    {
        int n = 0, m = 0;
        if (scanf("%d%d", &n, &m) != 2)
            break;
        // solve() puts its arguments in order itself, so no swap is needed here.
        printf("%lld\n", solve(n, m));
    }
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *