主要思路
其实就是一道傻逼搜索题,按照位运算的方式来进行加速。读者可以考虑先看 位运算加速 Sudoku 这篇文章,了解如何使用位运算来搞定普通数独。
其实这道题在搞定了位运算数独之后就是一道暴力搜索题,只不过要记录答案还有可能性。其实笔者一开始是觉得这道题可能是 DP 或者是贪心法,没想到暴搜竟然就这样 AC 了。
见下文代码。
代码
// CH2901.cpp #include <iostream> #include <cstdio> #define lowbit(num) (num & -num) using namespace std; const int VAL[9][9] = {{6, 6, 6, 6, 6, 6, 6, 6, 6}, {6, 7, 7, 7, 7, 7, 7, 7, 6}, {6, 7, 8, 8, 8, 8, 8, 7, 6}, {6, 7, 8, 9, 9, 9, 8, 7, 6}, {6, 7, 8, 9, 10, 9, 8, 7, 6}, {6, 7, 8, 9, 9, 9, 8, 7, 6}, {6, 7, 8, 8, 8, 8, 8, 7, 6}, {6, 7, 7, 7, 7, 7, 7, 7, 6}, {6, 6, 6, 6, 6, 6, 6, 6, 6}}; int map[81], X[81], Y[81], BOX[81], counter[(1 << 9)], slist[(1 << 9)]; int answer[81]; int sx[9], sy[9], sb[9], tmp, ans = 0; void setStat(int pos, int k) { tmp = 1 << k, sx[X[pos]] ^= tmp, sy[Y[pos]] ^= tmp, sb[BOX[pos]] ^= tmp; } int getCnt(int num) { int cnt = 0; while (num > 0) cnt++, num -= lowbit(num); return cnt; } int getFactor(int pos) { if (X[pos] == 0 || X[pos] == 8 || Y[pos] == 0 || Y[pos] == 8) return 6; if (X[pos] == 1 || X[pos] == 7 || Y[pos] == 1 || Y[pos] == 7) return 7; if (X[pos] == 2 || X[pos] == 6 || Y[pos] == 2 || Y[pos] == 6) return 8; if (X[pos] == 3 || X[pos] == 5 || Y[pos] == 3 || Y[pos] == 5) return 9; return 10; } void dfs(int remain, int prev) { if (!remain) { if (ans < prev) { ans = prev; for (int i = 0; i < 81; i++) answer[i] = map[i]; } return; } int min_val = 2e9, id; for (int i = 0; i < 81; i++) if (map[i] == 0) { int stat = sx[X[i]] & sy[Y[i]] & sb[BOX[i]]; if (stat == 0) return; if (counter[stat] < min_val) min_val = counter[stat], id = i; } int stat = sx[X[id]] & sy[Y[id]] & sb[BOX[id]]; while (stat > 0) { int digit = slist[lowbit(stat)]; setStat(id, digit); map[id] = digit + 1; dfs(remain - 1, prev + VAL[X[id]][Y[id]] * (digit + 1)); setStat(id, digit); map[id] = 0; stat -= lowbit(stat); } } int main() { for (int i = 0; i < 81; i++) scanf("%d", &map[i]); for (int i = 0; i < 81; i++) X[i] = i / 9, Y[i] = i % 9, BOX[i] = (int)(X[i] / 3) * 3 + (int)(Y[i] / 3); for (int i = 0; i < 9; i++) slist[1 << i] = i; for (int i = 0; i < (1 << 9); i++) counter[i] = getCnt(i); for (int i = 0; i < 81; i++) sx[i] = sy[i] = sb[i] = (1 << 9) - 1; int rm = 81; for (int i = 0; i < 81; i++) if (map[i] != 0) setStat(i, map[i] - 1), rm--; dfs(rm, 0); int sum = 0; for (int i = 0; i < 81; i++) sum += (answer[i]) * VAL[X[i]][Y[i]]; if (ans == 0) printf("-1"); else printf("%d", sum); return 0; }