【Codeforces 763D】Timofey and a flat tree

树哈希,dfs 序上差分。

题目大意:

给定一棵树,要求你确定一个根结点,使得其不同构的子树的个数最多。输出任意解。

题解:

判断两棵子树同构,一个常用的方法是树哈希。

我们先钦定 $1$ 为根,然后考虑 $x$ 作为子树的根结点时的情况:

  • 若树根为 $x$,则 $x$ 子树是整棵树。
  • 若树根在 $x$ 的子树内,则 $x$ 子树是在 $1$ 为根时的子树去掉树根所在子树,加上 $x$ 的父亲向上的子树。
  • 若树根在 $x$ 的子树外,则就是以 $1$ 为根时的 $x$ 子树本身。

考虑对每个点作为子树的根时,对每个点作为树的根的答案的贡献。

我们求树的 dfs 序,然后发现,这三种情况在 dfs 序上的贡献是连续的(第三种情况分为连续的两段)。

前两种情况的贡献,直接求以 $1$ 为根的每个子树的 hash 值即可。第三种情况做类似换根 dp 的操作即可。

我们求出子树的 hash 值,然后扔进一个数组里。最后对于 hash 值相同且贡献区间相交的,将其贡献区间合并,即可保证同构的子树不会产生重复贡献。然后用差分来解决区间加。

时间复杂度为按照 hash 值进行排序的 $O(n\log n)$。

Code:

#include<cstdio>
#include<ctime>
#include<cstdlib>
#include<vector>
#include<algorithm>
typedef unsigned long long LL;
using namespace std;
const int N=1e5+5;
int n,head[N],cnt,base1,base2,base3,dfn[N],idx,idfn[N],sz[N],fa[N];
int d[N];
LL hx[N],s[N];
struct edge{
    int to,nxt;
}e[N<<1];
struct data{
    LL val;
    int sz,l,r;
    inline bool operator<(const data&rhs)const{return val==rhs.val?(sz==rhs.sz?l<rhs.l:sz<rhs.sz):val<rhs.val;}
};
vector<data>vec;
void dfs(int now){
    hx[now]=0;
    idfn[dfn[now]=++idx]=now;
    sz[now]=1;
    LL sum=0;
    for(int i=head[now];i;i=e[i].nxt)if(!dfn[e[i].to]){
        fa[e[i].to]=now;
        dfs(e[i].to);
        sz[now]+=sz[e[i].to];
        hx[now]^=hx[e[i].to]*base3+base1;
        sum+=hx[e[i].to];
    }
    hx[now]^=(LL)sz[now]*base2+1;
    hx[now]+=sum;s[now]=sum;
    int l=dfn[now],r=dfn[now]+sz[now];
    vec.push_back((data){hx[now],sz[now],1,l-1});
    vec.push_back((data){hx[now],sz[now],r,n});
}
void dfs2(int now){
    LL vx=hx[now],nx=(hx[now]-s[now])^((LL)sz[now]*base2+1);
    for(int i=head[now];i;i=e[i].nxt)
    if(dfn[now]<dfn[e[i].to]){
        hx[now]=nx^(hx[e[i].to]*base3+base1);
        if(fa[now])hx[now]^=hx[fa[now]]*base3+base1;
        hx[now]^=(LL)(n-sz[e[i].to])*base2+1;
        hx[now]+=s[now]-hx[e[i].to]+hx[fa[now]];
        vec.push_back((data){hx[now],n-sz[e[i].to],dfn[e[i].to],dfn[e[i].to]+sz[e[i].to]-1});
        dfs2(e[i].to);
    }
    hx[now]=vx;
}
int main(){
    srand(time(0));
    base1=rand()%20000+10000,base2=rand()%30000+23333,base3=rand()%1919810+114514;
    scanf("%d",&n);
    for(int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        e[++cnt]=(edge){v,head[u]},head[u]=cnt;
        e[++cnt]=(edge){u,head[v]},head[v]=cnt;
    }
    dfs(1);
    dfs2(1);
    sort(vec.begin(),vec.end());
    data nw=vec[0];
    for(int i=1;i<=vec.size();++i)
    if(i==vec.size()||nw.val!=vec[i].val||nw.sz!=vec[i].sz||nw.r<vec[i].l){
        ++d[nw.l],--d[nw.r+1];
        nw=vec[i];
    }else nw.r=vec[i].r;
    for(int i=1;i<=n;++i)d[i]+=d[i-1];
    printf("%d\n",idfn[max_element(d+1,d+n+1)-d]);
    return 0;
}