Loading [MathJax]/extensions/tex2jax.js

CH1402:后缀数组题解

思路

一开始我写的是暴力 + 倍增排序，结果全盘 WA，怀疑人生。后来看了正解，发现要用二分 + 哈希，果然很毒瘤。

先准备数据，写好字符串哈希。如果读者不清楚如何求字符串哈希，可以参考这篇文章：字符串哈希例题

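然后，我们针对 \(sort\) 函数写一个比较函数。设原串长度为 \(len\)，两个后缀的起始位置（从 1 开始编号）分别为 \(a\) 和 \(b\)，则它们的长度分别为 \(len-a+1\) 和 \(len-b+1\)。我们在区间 \([0,\min\{len-a+1,len-b+1\}]\) 中二分出两者的最长公共前缀长度 \(pl\)。\(check(mid)\) 函数很好写：直接判断两个后缀长度为 \(mid\) 的前缀哈希值是否相等即可。求出 \(pl\) 后，比较两个后缀的第 \(pl+1\) 个字符即可确定它们的大小关系；若较短的后缀已经结束，则它更小。注意代码在字符串前补了一个空格，因此代码中的 str.length() 等于 \(len+1\)。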

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// CH1402.cpp
// Suffix array + height array via binary search on polynomial prefix hashes.
#include <iostream>
#include <algorithm>
#include <cstdio>  // printf
#include <string>  // std::string
using namespace std;

// Hash arithmetic below is meant to wrap modulo 2^64. That is only well-defined
// for an UNSIGNED type (signed overflow is undefined behaviour), so ull must be
// unsigned long long.
typedef unsigned long long ull;

// 300000 is presumably the problem's length limit; +1000 is slack for the
// leading sentinel and the 1-based indexing.
const int maxn = 300000 + 1000;
string str = "";
int idxs[maxn], siz;  // idxs[1..siz]: 1-based suffix start positions; siz: input length
ull hashtable[maxn], power[maxn], bitNum = 133;  // prefix hashes, powers of the base, hash base
// Strict-weak-ordering comparator for std::sort: returns true iff the suffix
// starting at 1-based position a is lexicographically smaller than the suffix
// starting at b. Equal prefix hashes are treated as equal strings (hash
// collisions are ignored).
bool cmp(int a, int b)
{
    // str has a leading ' ' sentinel, so the suffix at a has str.length() - a
    // characters. The common prefix can be at most the shorter suffix; using
    // "+ 1" here let the search step past the end of the string.
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        // Hash of str[x .. x+mid-1] is hashtable[x+mid-1] - hashtable[x-1] * power[mid].
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;  // first mid chars match: try a longer prefix
        else
            r = mid - 1;            // mismatch: the common prefix is shorter than mid
    }
    // pl <= range, so a + pl <= str.length(). str[str.length()] is '\0', which
    // makes a suffix that is a proper prefix of the other one compare smaller.
    return str[a + pl] < str[b + pl];
}
// Length of the longest common prefix of the suffixes starting at 1-based
// positions a and b, found by binary search on prefix-hash equality
// (hash collisions are ignored).
int getLen(int a, int b)
{
    // str has a leading ' ' sentinel, so the suffix at a has str.length() - a
    // characters; the LCP cannot exceed the shorter suffix. Using "+ 1" here let
    // the search probe one position past the end of the string.
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        // Hash of str[x .. x+mid-1] is hashtable[x+mid-1] - hashtable[x-1] * power[mid].
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;  // first mid chars match: try a longer prefix
        else
            r = mid - 1;            // mismatch: the LCP is shorter than mid
    }
    return pl;
}
// Reads one lowercase string, sorts its suffixes with cmp, then prints the
// 0-based suffix array on the first line and the height (LCP with the previous
// suffix in sorted order) array on the second line.
int main()
{
    // Read the text and shift it to 1-based indexing with a leading sentinel.
    cin >> str;
    siz = str.length();
    str = ' ' + str;

    // Prefix hashes, powers of the base, and the identity permutation of starts.
    power[0] = 1;
    for (int pos = 1; pos <= siz; ++pos)
    {
        hashtable[pos] = hashtable[pos - 1] * bitNum + str[pos] - 'a' + 1;
        power[pos] = power[pos - 1] * bitNum;
        idxs[pos] = pos;
    }

    sort(idxs + 1, idxs + siz + 1, cmp);

    // Line 1: suffix array, converted back to 0-based start positions.
    for (int rank = 1; rank <= siz; ++rank)
        printf("%d ", idxs[rank] - 1);

    // Line 2: height array; the first suffix has no predecessor, so it is 0.
    printf("\n0 ");
    for (int rank = 2; rank <= siz; ++rank)
        printf("%d ", getLen(idxs[rank], idxs[rank - 1]));
    return 0;
}
// CH1402.cpp
// Suffix array + height array via binary search on polynomial prefix hashes.
#include <iostream>
#include <algorithm>
#include <cstdio>  // printf
#include <string>  // std::string
using namespace std;

// Hash arithmetic wraps modulo 2^64; that is only well-defined for an unsigned type.
typedef unsigned long long ull;

const int maxn = 300000 + 1000;  // 300000 presumably the length limit; +1000 slack
string str = "";
int idxs[maxn], siz;             // idxs[1..siz]: 1-based suffix starts; siz: input length
ull hashtable[maxn], power[maxn], bitNum = 133;  // prefix hashes, base powers, base

// Strict-weak-ordering comparator: true iff the suffix at a < the suffix at b.
bool cmp(int a, int b)
{
    // The sentinel makes the suffix at a exactly str.length() - a chars long.
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;
        else
            r = mid - 1;
    }
    // a + pl <= str.length(); str[str.length()] == '\0' orders proper prefixes first.
    return str[a + pl] < str[b + pl];
}

// Longest common prefix length of the suffixes at a and b (binary search on hashes).
int getLen(int a, int b)
{
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;
        else
            r = mid - 1;
    }
    return pl;
}

// Prints the 0-based suffix array, then the height array (first entry 0).
int main()
{
    cin >> str;
    siz = str.length();
    str = ' ' + str, power[0] = 1;
    for (int i = 1; i <= siz; i++)
        hashtable[i] = hashtable[i - 1] * bitNum + str[i] - 'a' + 1,
        power[i] = power[i - 1] * bitNum,
        idxs[i] = i;
    sort(idxs + 1, idxs + siz + 1, cmp);
    for (int i = 1; i <= siz; i++)
        printf("%d ", idxs[i] - 1);
    printf("\n0 ");
    for (int i = 2; i <= siz; i++)
        printf("%d ", getLen(idxs[i], idxs[i - 1]));
    return 0;
}
// CH1402.cpp
// Suffix array + height array via binary search on polynomial prefix hashes.
#include <iostream>
#include <algorithm>
#include <cstdio>  // printf
#include <string>  // std::string
using namespace std;

// Hash arithmetic below is meant to wrap modulo 2^64. That is only well-defined
// for an UNSIGNED type (signed overflow is undefined behaviour), so ull must be
// unsigned long long.
typedef unsigned long long ull;

// 300000 is presumably the problem's length limit; +1000 is slack for the
// leading sentinel and the 1-based indexing.
const int maxn = 300000 + 1000;
string str = "";
int idxs[maxn], siz;  // idxs[1..siz]: 1-based suffix start positions; siz: input length
ull hashtable[maxn], power[maxn], bitNum = 133;  // prefix hashes, powers of the base, hash base

// Strict-weak-ordering comparator for std::sort: returns true iff the suffix
// starting at 1-based position a is lexicographically smaller than the suffix
// starting at b. Equal prefix hashes are treated as equal strings (hash
// collisions are ignored).
bool cmp(int a, int b)
{
    // str has a leading ' ' sentinel, so the suffix at a has str.length() - a
    // characters. The common prefix can be at most the shorter suffix; using
    // "+ 1" here let the search step past the end of the string.
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        // Hash of str[x .. x+mid-1] is hashtable[x+mid-1] - hashtable[x-1] * power[mid].
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;  // first mid chars match: try a longer prefix
        else
            r = mid - 1;            // mismatch: the common prefix is shorter than mid
    }
    // pl <= range, so a + pl <= str.length(). str[str.length()] is '\0', which
    // makes a suffix that is a proper prefix of the other one compare smaller.
    return str[a + pl] < str[b + pl];
}

// Length of the longest common prefix of the suffixes starting at 1-based
// positions a and b, found by binary search on prefix-hash equality
// (hash collisions are ignored).
int getLen(int a, int b)
{
    // str has a leading ' ' sentinel, so the suffix at a has str.length() - a
    // characters; the LCP cannot exceed the shorter suffix. Using "+ 1" here let
    // the search probe one position past the end of the string.
    int range = (int)min(str.length() - a, str.length() - b);
    int l = 0, r = range, pl = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        // Hash of str[x .. x+mid-1] is hashtable[x+mid-1] - hashtable[x-1] * power[mid].
        if (hashtable[a + mid - 1] - hashtable[a - 1] * power[mid] == hashtable[b + mid - 1] - hashtable[b - 1] * power[mid])
            l = mid + 1, pl = mid;  // first mid chars match: try a longer prefix
        else
            r = mid - 1;            // mismatch: the LCP is shorter than mid
    }
    return pl;
}

// Reads one lowercase string, sorts its suffixes with cmp, then prints the
// 0-based suffix array on the first line and the height (LCP with the previous
// suffix in sorted order) array on the second line.
int main()
{
    // Read the text and shift it to 1-based indexing with a leading sentinel.
    cin >> str;
    siz = str.length();
    str = ' ' + str;

    // Prefix hashes, powers of the base, and the identity permutation of starts.
    power[0] = 1;
    for (int pos = 1; pos <= siz; ++pos)
    {
        hashtable[pos] = hashtable[pos - 1] * bitNum + str[pos] - 'a' + 1;
        power[pos] = power[pos - 1] * bitNum;
        idxs[pos] = pos;
    }

    sort(idxs + 1, idxs + siz + 1, cmp);

    // Line 1: suffix array, converted back to 0-based start positions.
    for (int rank = 1; rank <= siz; ++rank)
        printf("%d ", idxs[rank] - 1);

    // Line 2: height array; the first suffix has no predecessor, so it is 0.
    printf("\n0 ");
    for (int rank = 2; rank <= siz; ++rank)
        printf("%d ", getLen(idxs[rank], idxs[rank - 1]));
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *