BZOJ2288:生日礼物题解

思路

其实这道题就是在求一个序列中最多\(M\)个(可相互重叠的)子序列的最大和。我们可以把序列中连续的正数和负数全部合并在一起:正数和正数合并,负数和负数合并。然后我把所有的正数全部加入答案之中,以他们的绝对值为关键字放入小根堆中,并且计个数,与此同时再创建双向链表结构。之后,我们可以以\(cnt>m\)作为循环条件,不停的取出队中的元素来进行处理。

在处理的时候,我们先来判断情况。如果这个点不处于边缘位置(也就是位置\(prv[pos]!=0且nxt[pos]!=tot+1\)),那么我们直接减去它的绝对值。如果它是处于边缘的正数,那么直接从答案中减去就好!但是如果它很不幸,未满足以上条件,也就是它是在边缘的负数,那么我们直接把方案数减一并进入下一次循环就好。

之后我们就来合并。我们把第\(i、i-1、i+1\)元素全部合并在一起,然后再放入堆中,因为之后如果我们再次选择它,那么就意味着我们选择了第\(i-1,i+1\)个元素。合并之后,就会出现一个问题:如果删去前驱后继,那么某一次大循环中,被删除的元素出现后将无法运算。我们只需要准备一个\(vis[]\)就好,在循环开始的时候不停的筛选,直到符合条件——也就是存在时停止。

最后当方案套数小于限定值\(M\)时,循环停止,输出答案。(当然也存在处理前方案数就小的情况,那岂不更好)。

代码

// CH1812.cpp
#include <iostream>
#include <queue>
#include <cstdio>
#include <cstring>
#include <cmath>
#define ll long long
using namespace std;
const int maxn = 100200;
ll arr[maxn], prv[maxn], nxt[maxn], N, M, merged[maxn];
bool vis[maxn];
struct node
{
    ll pid, val;
    node() {}
    node(ll pd, ll v) { pid = pd, val = v; }
    bool operator<(const node &nd) const { return abs(val) > abs(nd.val); }
};
priority_queue<node> nds;
void del(ll p) { vis[p] = false, prv[nxt[p]] = prv[p], nxt[prv[p]] = nxt[p]; }
int main()
{
    scanf("%lld%lld", &N, &M);
    for (int i = 1; i <= N; i++)
        scanf("%lld", &arr[i]);
    ll tot = 1;
    // merge them to one.
    for (int i = 1; i <= N; i++)
        if (arr[i] * merged[tot] >= 0)
            merged[tot] += arr[i];
        else
            merged[++tot] = arr[i];
    ll ans = 0, cnt = 0;
    for (int i = 1; i <= tot; i++)
    {
        nds.push(node((ll)i, merged[i]));
        if (merged[i] > 0)
            ans += merged[i], cnt++;
        prv[i] = i - 1, nxt[i] = i + 1;
    }
    memset(vis, true, sizeof(vis));
    while (cnt > M)
    {
        cnt--;
        node curt = nds.top();
        while (!vis[curt.pid])
            nds.pop(), curt = nds.top();
        nds.pop();
        int i = curt.pid;
        if (prv[i] && nxt[i] != tot + 1)
            ans -= abs(merged[i]);
        else if (merged[i] > 0)
            ans -= merged[i];
        else
        {
            cnt++;
            continue;
        }
        merged[i] += merged[prv[i]] + merged[nxt[i]];
        del(prv[i]), del(nxt[i]);
        curt.val = merged[i], nds.push(curt);
    }
    printf("%lld", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *