Loading [MathJax]/extensions/tex2jax.js

01背包的前 k 优解问题

简述

最近准备开始复习背包,看《背包九讲》之前正好随机到了这道题,所以来做一个简单的总结。

原理

首先回忆 01 背包的 DP 递推式:

\[ dp[j] = \max\{ dp[j], dp[j – weight_i] + value_i \} \]

我们可以尝试加一维,记录为第\(x\)优解:\(dp[x][j]\)。如何从之前的更优解合并为次优解是这道题的难点。首先我们发现,对于每一个\(j\),\(\{ dp[1][j], dp[2][j], dp[3][j], \dots \}\)是单调下降的。那么,我们如果要合并出一个新的单调下降序列,我们就可以使用归并排序的合并方式:只不过不是在两个序列上,而是在两种决策中选择,我们可以把选择第\(i\)个的和不选择的答案虚拟成两个长度为\(k\)的序列,然后用归并的方式合并。

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// P1858.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 5050;
int dp[55][MAX_N], k, m, n, wi[MAX_N], vi[MAX_N], tmp[MAX_N];
int main()
{
scanf("%d%d%d", &k, &m, &n);
for (int i = 1; i <= n; i++)
scanf("%d%d", &wi[i], &vi[i]);
memset(dp, 128, sizeof(dp));
dp[1][0] = 0;
for (int i = 1; i <= n; i++)
for (int j = m; j >= wi[i]; j--)
{
int ptr1 = 1, ptr2 = 1, ptr = 0;
while (ptr <= k)
if (dp[ptr1][j] > dp[ptr2][j - wi[i]] + vi[i])
tmp[++ptr] = dp[ptr1++][j];
else
tmp[++ptr] = dp[ptr2++][j - wi[i]] + vi[i];
for (int pt = 1; pt <= k; pt++)
dp[pt][j] = tmp[pt];
}
int ans = 0;
for (int i = 1; i <= k; i++)
ans += dp[i][m];
printf("%d", ans);
return 0;
}
// P1858.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 5050; int dp[55][MAX_N], k, m, n, wi[MAX_N], vi[MAX_N], tmp[MAX_N]; int main() { scanf("%d%d%d", &k, &m, &n); for (int i = 1; i <= n; i++) scanf("%d%d", &wi[i], &vi[i]); memset(dp, 128, sizeof(dp)); dp[1][0] = 0; for (int i = 1; i <= n; i++) for (int j = m; j >= wi[i]; j--) { int ptr1 = 1, ptr2 = 1, ptr = 0; while (ptr <= k) if (dp[ptr1][j] > dp[ptr2][j - wi[i]] + vi[i]) tmp[++ptr] = dp[ptr1++][j]; else tmp[++ptr] = dp[ptr2++][j - wi[i]] + vi[i]; for (int pt = 1; pt <= k; pt++) dp[pt][j] = tmp[pt]; } int ans = 0; for (int i = 1; i <= k; i++) ans += dp[i][m]; printf("%d", ans); return 0; }
// P1858.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 5050;

int dp[55][MAX_N], k, m, n, wi[MAX_N], vi[MAX_N], tmp[MAX_N];

int main()
{
    scanf("%d%d%d", &k, &m, &n);
    for (int i = 1; i <= n; i++)
        scanf("%d%d", &wi[i], &vi[i]);
    memset(dp, 128, sizeof(dp));
    dp[1][0] = 0;
    for (int i = 1; i <= n; i++)
        for (int j = m; j >= wi[i]; j--)
        {
            int ptr1 = 1, ptr2 = 1, ptr = 0;
            while (ptr <= k)
                if (dp[ptr1][j] > dp[ptr2][j - wi[i]] + vi[i])
                    tmp[++ptr] = dp[ptr1++][j];
                else
                    tmp[++ptr] = dp[ptr2++][j - wi[i]] + vi[i];
            for (int pt = 1; pt <= k; pt++)
                dp[pt][j] = tmp[pt];
        }
    int ans = 0;
    for (int i = 1; i <= k; i++)
        ans += dp[i][m];
    printf("%d", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *