Universal OJ:#42「清华集训 2014」Sum – 题解

主要思路

这转化也真是妙…

首先,我们把 \((-1)^b\) 进行转换。发现转换成跟数论相关:\((-1)^b = 1 – 2(b \bmod 2)\),则可以把这个式子弄一下:

\[ \begin{aligned} & \sum_{i = 1}^n 1 – 2(\lfloor i \sqrt r \rfloor \bmod 2) \\ =& n -2\sum_{i = 1}^n \lfloor i \sqrt r \rfloor – 2 \lfloor \frac{i \sqrt r}{2} \rfloor \end{aligned} \]

考虑计算 \(\sum_{i = 1}^n \lfloor i \sqrt r \rfloor\),因为 \(\sqrt r\)是个无理数还挺烦,所以我们可以试着用类欧来包装下,一开始 \(a = 1, b = 0, c = 1\):

\[ \begin{aligned} & \sum_{i = 1}^n \lfloor i \sqrt r \rfloor \\ =& \sum_{i = 1}^n \lfloor i \cdot \frac{a \sqrt r + b}{c} \rfloor \end{aligned} \]

这个的意义相当于数这个斜率的直线与 \(x\) 轴做成的下三角形中整点的个数。我们对斜率进行分类讨论。

\(k = \frac{a \sqrt r + b}{c} > 1\),则可以直接用代数:

\[ \begin{aligned} & \sum_{i = 1}^n \lfloor i \cdot \frac{a \sqrt r + b}{c} \rfloor \\ =& \sum_{i = 1}^n \lfloor i \cdot (\frac{a \sqrt r + b}{c} – \lfloor \frac{a \sqrt r + b}{c} \rfloor) \rfloor + i \lfloor \frac{a \sqrt r + b}{c} \rfloor \end{aligned} \]

如果不是,那么我们可以把这条直线斜过去(把这个直线关于 \(y = x\) 进行对称),然后用矩形面积减去下三角形的点数。新的斜率为 \(\frac{1}{\frac{a\sqrt r + b}{c}} = \frac{ac\sqrt r – bc}{a^2 r – b^2}\),然后矩形在 \(x\) 轴上的投影为 \(\frac{a\sqrt r + b}{c} \cdot n\)。万事俱备,直接类欧即可。

代码

// UOJ42.cpp
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

ll T, n, r;
double spart;

ll gcd(ll a, ll b) { return b == 0 ? a : gcd(b, a % b); }

ll calc(ll a, ll b, ll c, ll n)
{
    if (n == 0)
        return 0;
    if (n == 1)
        return (a * spart + b) / c;
    ll d = gcd(gcd(a, b), c);
    a /= d, b /= d, c /= d;
    ll verdict = (a * spart + b) / c;
    if (verdict == 0)
    {
        // transform the axis system;
        ll rect_area = ll((a * spart + b) / c * n) * n;
        return rect_area - calc(a * c, -c * b, a * a * r - b * b, (a * spart + b) / c * n);
    }
    else
        return calc(a, b - c * verdict, c, n)  + verdict * n * (n + 1) / 2;
}

int main()
{
    scanf("%lld", &T);
    while (T--)
    {
        scanf("%lld%lld", &n, &r);
        spart = sqrt(r);
        if (ll(spart) * ll(spart) == r)
        {
            // calc;
            if (ll(spart) % 2 == 0)
                printf("%lld\n", n);
            else
                printf("%lld\n", n - 2 * ((n + 1) >> 1));
            continue;
        }
        printf("%lld\n", n - 2LL * calc(1, 0, 1, n) + 4LL * calc(1, 0, 2, n));
    }
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *