【UOJ#33】【UR#2】树上GCD

点分治,根号分治

题目大意:

给定一个有 $n$ 个结点的有根树。

定义 $\mathrm{dist}(u,v)$ 表示 $u$ 和 $v$ 在树上的距离。

令 $f(u,v)=\gcd(\mathrm{dist}(u,\mathrm{LCA}(u,v)),\mathrm{dist}(v,\mathrm{LCA}(u,v)))$。

对 $i\in[1,n-1]$,求出有多少对 $u,v$ 满足 $u<v$ 且 $f(u,v)=i$。

题解:

首先恰好等于 $\gcd$ 的方案不太好求,考虑求是 $\gcd$ 的倍数的方案,最后进行一个 $O(n\log n)$ 的容斥。

显然当两个数都是 $d$ 的倍数时,它们的 $\gcd$ 才是 $d$ 的倍数。

考虑点分治。

对于当前分治中心 $x$,它的儿子分成两种:$x$ 在原树中的儿子结点和 $x$ 在原树中的父亲结点。

对于所有原树中的儿子结点所在的连通块,我们可以一起计算贡献,因为任意两个在不同连通块中的结点,它们的 $\mathrm{LCA}$ 必然为 $x$。因此我们统计出每个连通块中,每个距离的出现次数,然后在 $O(n\log n)$ 得到每个深度的倍数的出现次数,进而可以计算贡献。加上点分治,这部分的总复杂度为 $O(n\log^2 n)$。

对于父亲结点所在的连通块,我们可以找到它最上面那个结点 $c$,然后考虑 $x\sim c$ 路径上这些结点(不包括 $x$)作为 $\mathrm{LCA}$ 的时候的贡献(由于是在点分治中计算贡献,所以仅计算这个连通块和其他连通块之间的贡献)。

对于一个结点 $k$ 作为 $\mathrm{LCA}$,设它与 $x$ 距离为 $m$,我们统计出它下面的子树中的每个距离的出现次数。然后对于距离 $d$,我们就要在 $x$ 的儿子所在连通块中找到距离 $k$ 恰好为 $d$ 的倍数的结点的个数。相当于到 $x$ 距离满足模 $d$ 余 $(d-m\bmod d)\bmod d$ 的结点的个数。

考虑根号分治。预处理出 $x$ 儿子结点所在连通块中的到 $x$ 距离的出现次数。

设当前分治连通块大小为 $n$。

对于 $d\gt\sqrt n$,我们直接暴力查询满足条件的数的出现次数,单次复杂度 $O(\sqrt n)$。

对于 $d\leq \sqrt n$,我们预处理 $b_{i,j}$ 表示模 $i$ 余 $j$ 的点的出现次数。预处理复杂度 $O(n\sqrt n)$,单次查询复杂度 $O(1)$。

因此对这样一个连通块计算贡献的时间复杂度为 $O(n\log n+n\sqrt n)$。

根据主定理可以得到,算法的总时间复杂度为 $O(n\log^2 n+n\sqrt n)$。

Code:

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
const int N=2e5+5;
typedef long long LL;
int n,fa[N],head[N],cnt,mxd[N],rt,sz[N],all;
LL g[N],dt[N];
bool vis[N];
struct edge{
    int to,nxt;
}e[N<<1];
void addedge(int u,int v){
    e[++cnt]=(edge){v,head[u]},head[u]=cnt;
    e[++cnt]=(edge){u,head[v]},head[v]=cnt;
}
void get_centroid(int now,int pre){
    sz[now]=1,mxd[now]=0;
    for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=pre&&!vis[e[i].to]){
        get_centroid(e[i].to,now);
        sz[now]+=sz[e[i].to];
        mxd[now]=max(mxd[now],sz[e[i].to]);
    }
    mxd[now]=max(mxd[now],all-sz[now]);
    if(!rt||mxd[rt]>mxd[now])rt=now;
}
void dfs(int now,int pre,int depth,vector<int>&ct){
    if(ct.size()<depth+1)ct.resize(depth+1);
    ++ct[depth];
    for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=pre&&!vis[e[i].to])
    dfs(e[i].to,now,depth+1,ct);
}
void work_simple(int now){
    static vector<int>tot,cnt;
    for(int i=head[now];i;i=e[i].nxt)if(fa[e[i].to]==now&&!vis[e[i].to]){
        const int to=e[i].to;
        dfs(to,now,1,tot);
        for(int i=1,lm=(int)tot.size();i<lm;++i)
        for(int j=i*2;j<lm;j+=i)tot[i]+=tot[j];
        if(cnt.size()<tot.size())
        cnt.resize(tot.size());
        for(int i=1,lm=(int)tot.size();i<lm;++i)
        cnt[i]+=tot[i],g[i]-=(LL)tot[i]*tot[i];
        tot.clear();
    }
    for(int i=1;i<(int)cnt.size();++i)g[i]+=(LL)cnt[i]*cnt[i];
    cnt.clear();
}
int bs[535][535],dis[N],B;
void dfs0(int now,int pre,int d=0){
    ++dis[d];
    for(int i=1;i<=B;++i)++bs[i][d%i];
    for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=pre&&!vis[e[i].to])
    dfs0(e[i].to,now,d+1);
}
int query(int x,int d){
    if(d<=B)return bs[d][x];
    int res=0;
    for(int i=x;i<=all;i+=d)res+=dis[i];
    return res;
}
void work_special(int now,int rt){
    B=(int)sqrt(all)/5+1;
    for(int i=1;i<=B;++i)
    for(int j=0;j<i;++j)bs[i][j]=0;
    for(int i=0;i<=all;++i)dis[i]=0;
    dfs0(now,fa[now]);
    int exd=0;
    static vector<int>tot;
    for(int x=now;x!=rt;x=fa[x]){
        tot.clear();
        int k=fa[x];
        ++exd;
        for(int i=head[k];i;i=e[i].nxt)if(e[i].to!=x&&fa[e[i].to]==k&&!vis[e[i].to])
        dfs(e[i].to,k,1,tot);
        for(int i=1,lm=(int)tot.size();i<lm;++i)
        for(int j=i*2;j<lm;j+=i)tot[i]+=tot[j];
        for(int i=1;i<(int)tot.size();++i)g[i]+=2LL*tot[i]*query((i-exd%i)%i,i);
    }
}
void solve(int now,int root){
    vis[now]=1;
    work_simple(now);
    if(now!=root)work_special(now,root);
    int sm=all;
    for(int i=head[now];i;i=e[i].nxt)if(!vis[e[i].to]){
        all=sz[now]>sz[e[i].to]?sz[e[i].to]:sm-sz[now];
        rt=0;get_centroid(e[i].to,now);
        solve(rt,fa[e[i].to]==now?e[i].to:root);
    }
}
void dfs2(int now,int dep){
    ++dt[1],--dt[dep+1];
    for(int i=head[now];i;i=e[i].nxt)if(fa[e[i].to]==now)
    dfs2(e[i].to,dep+1);
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=2;i<=n;++i)cin>>fa[i],addedge(i,fa[i]);
    all=n,rt=0;
    get_centroid(1,0);
    solve(rt,1);
    for(int i=n;i;--i)
    for(int j=i*2;j<=n;j+=i)g[i]-=g[j];
    dfs2(1,0);
    for(int i=1;i<n;++i)dt[i]+=dt[i-1];
    for(int i=1;i<n;++i)printf("%lld\n",dt[i]+g[i]/2);
    return 0;
}