【CF504E】Misha and LCP on Tree
二分,哈希,长链剖分
题目大意:
给定一棵树,每个点有一个字符。
每次询问给出两条有向路径,路径经过的点按次序排列会组成两个字符串。要求回答这两个字符串的最长公共前缀长度。
题解:
序列上的问题,我们可以用二分+哈希的方法,在 $O(n+m\log n)$ 的时间内解决问题。在知道两个位置的前缀或后缀哈希值的情况下,我们可以 $O(1)$ 计算中间这一段的哈希值。
现在问题搬到了树上,但对于树上的一条路径,我们仍然可以 $O(1)$ 得到它的哈希值。具体就是分别维护到根路径的正串哈希值和反串哈希值($O(n)$ 预处理),然后可以通过到根的前缀/后缀 $O(1)$ 计算一条向根路径的哈希值。我们在 LCA 处断开,计算两段哈希值然后 $O(1)$ 合并就是该路径的哈希值。
在树上求哈希值的复杂度并没有变差,所以仍然可以 $O(\log n)$ 回答一次询问。
由于我们要二分 LCP 长度,所以我们要知道对每个长度,其结尾的点的位置,相当于一个求 $k$ 级祖先的问题。这里要求 $m\log n$ 次,所以用倍增的话复杂度就会多 $\log n$。
不难想到用长链剖分做到 $O(n\log n)$ 预处理,$O(1)$ 查询 $k$ 级祖先。
这样总时间复杂度是 $O((n+m)\log n)$。
需要注意常数还有不要被卡哈希,双哈希比较稳。这里由于 $m>n$,所以使用 ST 表的 $O(n\log n)$ 预处理 $O(1)$ 查询,常数会更小。不要用倍增求 LCA 那个慢死了。
Code:
#include<cstdio>
#include<ctime>
#include<cstdlib>
#include<algorithm>
using namespace std;
typedef long long LL;
#define lg2(x)(31-__builtin_clz(x))
const int md1=1004535809,md2=167772161,N=3e5+5;
int base1,base2;
int _1[N],_2[N],n,m;
int i1[N],i2[N];
char s[N];
int head[N],cnt;
int mxd[N],toph[N],tail[N],sonh[N],len[N],fa[N],F[19][N],dep[N],idx,dfn[N];
int*up[N],*down[N];
int st[20][N<<1];
inline int Inv(int a,const int&md){
int ret=1,b=md-2;
for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
return ret;
}
struct edge{
int to,nxt;
}e[N<<1];
struct data{
int s1,s2;
int len;
inline void push_front(char c){s1=(s1+(LL)c*_1[len])%md1,s2=(s2+(LL)c*_2[len++])%md2;}
inline void push_back(char c){s1=((LL)s1*base1+c)%md1,s2=((LL)s2*base2+c)%md2,++len;}
inline data operator+(const data&rhs)const{return(data){((LL)s1*_1[rhs.len]+rhs.s1)%md1,((LL)s2*_2[rhs.len]+rhs.s2)%md2,len+rhs.len};}
inline bool operator==(const data&rhs)const{return s1==rhs.s1&&s2==rhs.s2&&len==rhs.len;}
}a[N],b[N];
void dfs(int now){
st[0][dfn[now]=++idx]=now;
sonh[now]=0,mxd[now]=dep[now];
a[now]=a[fa[now]],b[now]=b[fa[now]];
a[now].push_back(s[now]),b[now].push_front(s[now]);
for(int i=head[now];i;i=e[i].nxt)
if(!dep[e[i].to]){
dep[e[i].to]=dep[now]+1;
fa[e[i].to]=F[0][e[i].to]=now;
dfs(e[i].to);
st[0][++idx]=now;
if(mxd[e[i].to]>mxd[now])mxd[now]=mxd[e[i].to],sonh[now]=e[i].to;
}
}
void dfs2(int now){
tail[toph[now]]=now,len[toph[now]]=dep[now]-dep[toph[now]]+1;
if(sonh[now])toph[sonh[now]]=toph[now],dfs2(sonh[now]);
for(int i=head[now];i;i=e[i].nxt)
if(dep[e[i].to]>dep[now]&&e[i].to!=sonh[now])dfs2(toph[e[i].to]=e[i].to);
}
inline int kfa(int x,int k){
if(dep[x]<=k)return 0;
if(k==0)return x;
const int lg=lg2(k);
x=F[lg][x],k-=1<<lg;
if(!k)return x;
const int dlt=dep[x]-dep[toph[x]];
if(dlt>=k)
return down[toph[x]][dlt-k];
return up[toph[x]][k-dlt];
}
inline int LCA(int x,int y){
if(dfn[x]>dfn[y])swap(x,y);
x=dfn[x],y=dfn[y];
int lg=lg2(y-x+1);
int a=st[lg][x],b=st[lg][y-(1<<lg)+1];
return(dep[a]<dep[b])?a:b;
}
void init(){
for(int i=1;i<19;++i)
for(int j=1;j<=n;++j)
F[i][j]=F[i-1][F[i-1][j]];
for(int i=1;i<=n;++i)
if(toph[i]==i){
int h=len[i];
up[i]=new int[h+2];
down[i]=new int[h+2];
*up[i]=*down[i]=i;
for(int j=1,nw=fa[i];j<=h&&nw;++j,nw=fa[nw])
up[i][j]=nw;
for(int j=1,nw=sonh[i];j<=h&&nw;++j,nw=sonh[nw])
down[i][j]=nw;
}
for(int i=1;i<20;++i)
for(int j=1;j<=2*n;++j){
int x=st[i-1][j],y=st[i-1][j+(1<<i-1)];
st[i][j]=(dep[x]<dep[y])?x:y;
}
}
inline data get(int u,int v){
const data&x=b[u],&y=b[v];
return(data){(LL)(x.s1-y.s1+md1)*i1[y.len]%md1,(LL)(x.s2-y.s2+md2)*i2[y.len]%md2,x.len-y.len};
}
inline data get_(int u,int v){
const data&x=a[u],&y=a[v];
return(data){(y.s1-(LL)x.s1*_1[y.len-x.len]%md1+md1)%md1,(y.s2-(LL)x.s2*_2[y.len-x.len]%md2+md2)%md2,y.len-x.len};
}
int main(){
srand(time(0));
base1=rand()%100+200,base2=rand()%300+400;
scanf("%d%s",&n,s+1);
_1[0]=_2[0]=i1[0]=i2[0]=1;
for(int i=1;i<=n;++i)
_1[i]=(LL)_1[i-1]*base1%md1,_2[i]=(LL)_2[i-1]*base2%md2;
i1[1]=Inv(base1,md1),i2[1]=Inv(base2,md2);
for(int i=2;i<=n;++i)
i1[i]=(LL)i1[i-1]*i1[1]%md1,i2[i]=(LL)i2[i-1]*i2[1]%md2;
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
e[++cnt]=(edge){v,head[u]},head[u]=cnt;
e[++cnt]=(edge){u,head[v]},head[v]=cnt;
}
dfs(dep[1]=1),dfs2(toph[1]=1);
init();
for(scanf("%d",&m);m--;){
int u1,v1,u2,v2;
scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
int L1=LCA(u1,v1),L2=LCA(u2,v2);
const int D1=dep[u1]+dep[v1]-2*dep[L1]+1,D2=dep[u2]+dep[v2]-2*dep[L2]+1;
int l=1,r=min(D1,D2),ans=0;
while(l<=r){
const int mid=l+r>>1;
const data _x=dep[u1]-mid+1>=dep[L1]?get(u1,kfa(u1,mid)):get(u1,L1)+get_(fa[L1],kfa(v1,D1-mid));
const data _y=dep[u2]-mid+1>=dep[L2]?get(u2,kfa(u2,mid)):get(u2,L2)+get_(fa[L2],kfa(v2,D2-mid));
if(_x==_y)l=(ans=mid)+1;else r=mid-1;
}
printf("%d\n",ans);
}
return 0;
}