最近重温了一下双栈排序,发现了一个惊天 BUG。 常规的算法大概是这样的:
建图,构建出约束关系:
- 对于每一个位置 $i$,建立一个点 $v_i$。
- 对于 $i < j$,满足 $a_i < a_j$ 且有 $k > j$ 满足 $a_k < a_i$,则连接 $v_i\leftrightarrow v_j$ 的无向边。
对这个图进行二分图黑白染色,如果不能染色,那么说明无解。否则这种染色方案对应一个可行解:顺着考虑序列 $P$,如果点 $v_i$ 是白色,那么它入栈 $S_1$,否则入栈 $S_2$,入栈之后能出栈的全部出栈。
这样的复杂度是 $O(n^2)$ 的(我知道有更优的复杂度),足够通过 OJ 上的测试数据了,但是问题就在于「入栈之后能出栈的全部出栈」这一句话是有问题的。
我们来看一组数据:
4
2 3 1 4
网上能找到的几乎所有程序跑出来的结果是 a c a b b d a b
,然而答案 a c a b b a d b
确是一个更优的可行解。
原因就在于:「入栈之后能出栈的全部出栈」在某些情况下会导致不优,不难发现只有一种情况,那就是「先出栈 $S_2$,再入栈 $S_1$」,我们可以调整为「先入栈 $S_1$ 再出栈 $S_2$」,这样操作就由 d a
变成了 a d
,仍然合法而且字典序更小。然而数据中并没有包含这一种情况,所以,没有考虑到这一点的程序也 AC 了。(网上我只发现 __debug 的程序是对的)。
解决方案是在每次出栈的时候,先全部出栈,然后再对这一次出栈的序列进行扫描,假设这一次的序列是 b d b d d d d d
,找到最长的后缀 d
,如果长度大于 0 而且且下一个点入栈 $S_1$,那么就将后面连续的操作 d
全部撤销掉。
不难发现更改之后的算法,每个点最多出入栈 2 次,所以复杂度没有变化仍然是 $O(n^2)$ 的。
代码
// Created by Sengxian on 2016/11/04.
// Copyright (c) 2016年 Sengxian. All rights reserved.
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1000 + 3;
int n, a[MAX_N], mn[MAX_N], color[MAX_N];
bool G[MAX_N][MAX_N];
bool Dfs(int u, int c) {
color[u] = c;
for (int v = 0; v < n; ++v) if (G[u][v]) {
if (color[v] == c) return false;
if (color[v] == 0 && !Dfs(v, -c)) return false;
}
return true;
}
stack<int> A, B;
vector<char> ans;
int now = 0;
void go(int i) {
vector<int> vec;
while (true) {
bool update = false;
if (!A.empty() && A.top() == now) A.pop(), now++, vec.push_back('b'), update = true;
if (!B.empty() && B.top() == now) B.pop(), now++, vec.push_back('d'), update = true;
if (!update) break;
}
if (i + 1 < n && color[i + 1] == -1) {
for (int j = (int)vec.size() - 1; j >= 0; --j)
if (vec[j] == 'b' || (j == 0 && vec[j] == 'd')) { // 特判
int len = (int)vec.size() - j - 1;
if (j == 0 && vec[j] == 'd') len = (int)vec.size();
for (int k = now - 1; k > now - 1 - len; --k)
B.push(k);
now -= len;
vec.resize((int)vec.size() - len);
break;
}
}
ans.insert(ans.end(), vec.begin(), vec.end());
}
void solve() {
for (int i = 0; i < n; ++i)
if (color[i] == -1) {
A.push(a[i]), ans.push_back('a');
go(i);
} else if (color[i] == 1) {
B.push(a[i]), ans.push_back('c');
go(i);
}
for (int i = 0; i < (int)ans.size(); ++i)
printf("%c%c", ans[i], i + 1 == (int)ans.size() ? '\n' : ' ');
}
int main() {
#ifdef DEBUG
freopen("test.in", "r", stdin);
#endif
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%d", a + i), a[i]--;
mn[n] = n;
for (int i = n - 1; i >= 0; --i) mn[i] = min(mn[i + 1], a[i]);
for (int i = 0; i < n; ++i)
for (int j = i + 1; j < n; ++j) if (a[i] < a[j] && mn[j + 1] < a[i])
G[i][j] = G[j][i] = true;
for (int i = 0; i < n; ++i) if (!color[i])
if (!Dfs(i, -1)) return puts("0"), 0;
solve();
return 0;
}