【Codeforces 452E】Three strings

后缀数组

题目大意:

给定三个串 $A,B,C$,令 $S(l,r)$ 表示 $S$ 的字符 $l\sim r$ 构成的子串。

对于每个 $L$ 求出 $A(a,a+L-1)=B(b,b+L-1)=C(c,c+L-1)$ 的三元组 $(a,b,c)$ 个数。

题解:

很妙的一个做法。

将三个串拼起来以后求后缀数组。

用并查集维护每个位置为开头的串的信息。

按照 $height$ 数组从大到小,每次把两个串的信息合并,合并信息的时候计算贡献,并加到当前的 $height$ 对应长度的答案里。

由于 $height$ 从大到小枚举,因此每次合并的两个连通块中的串有且仅有前 $height$ 个字符完全相等。

最后,由于我们把一个三元组的贡献加到了其最长的 $height$ 里,所以要做一遍后缀和。

时间复杂度 $O(N\log N)$。

Code:

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=3e5+9,md=1e9+7;
typedef long long LL;
char A[N],B[N],C[N];
int len,pos[N],k;
int s[N],x[N],y[N],c[N],height[N],m=127,sa[N];
int ans[N];
struct St{
    int ht,id;
    inline bool operator<(const St&rhs)const{return ht>rhs.ht;}
}d[N];
int F[N][4],fa[N];
void ssort(){
    for(int i=1;i<=len;++i)++c[x[i]=s[i]];
    for(int i=1;i<=m;++i)c[i]+=c[i-1];
    for(int i=len;i;--i)sa[c[x[i]]--]=i;
    for(int k=1;k<=len;k<<=1){
        int p=0;
        for(int i=len-k+1;i<=len;++i)
            y[++p]=i;
        for(int i=1;i<=len;++i)
            if(sa[i]>k)y[++p]=sa[i]-k;
        for(int i=0;i<=m;++i)c[i]=0;
        for(int i=1;i<=len;++i)++c[x[i]];
        for(int i=1;i<=m;++i)c[i]+=c[i-1];
        for(int i=len;i;--i)sa[c[x[y[i]]]--]=y[i];
        swap(x,y);
        x[sa[1]]=p=1;
        for(int i=2;i<=len;++i)
            x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]?p:++p);
        if(p==len)break;
        m=p;
    }
    for(int i=1,k=0;i<=len;++i)
    if(x[i]!=1){
        k-=!!k;
        int j=sa[x[i]-1];
        while(s[i+k]==s[j+k])++k;
        height[x[i]]=k;
    }
}
inline int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}
int main(){
    scanf("%s%s%s",A,B,C);
    k=min({strlen(A),strlen(B),strlen(C)});
    for(int i=0;A[i];++i)s[++len]=A[i],pos[len]=1;
    s[++len]=++m;
    for(int i=0;B[i];++i)s[++len]=B[i],pos[len]=2;
    s[++len]=++m;
    for(int i=0;C[i];++i)s[++len]=C[i],pos[len]=3;
    ssort();
    for(int i=1;i<=len;++i)d[i]=(St){height[i],i},fa[i]=i,F[i][pos[i]]=1;
    sort(d+1,d+len+1);
    for(int i=1;i<=len;++i)if(d[i].ht>0){
        int x=find(sa[d[i].id]),y=find(sa[d[i].id-1]);
        if(x!=y){
            ans[d[i].ht]=(ans[d[i].ht]+(LL)F[x][1]*F[y][2]*F[y][3]%md+(LL)F[x][2]*F[y][1]*F[y][3]%md+
            (LL)F[x][3]*F[y][1]*F[y][2]%md+(LL)F[y][1]*F[x][2]*F[x][3]%md
            +(LL)F[y][2]*F[x][1]*F[x][3]%md+(LL)F[y][3]*F[x][1]*F[x][2]%md)%md;
            F[x][1]+=F[y][1],F[x][2]+=F[y][2],F[x][3]+=F[y][3];
            fa[y]=x;
        }
    }else break;
    for(int i=k-1;i;--i)(ans[i]+=ans[i+1])%=md;
    for(int i=1;i<=k;++i)
    printf("%d ",ans[i]);
    return 0;
}