【CSP2019】树的重心

线段树合并,树状数组

题目大意:

给定一棵树。断掉其一条边后,该树会变成两棵树,这两棵树每棵树各有 $1$ 或 $2$ 个重心。断掉这条边的贡献为,两棵树的所有重心的标号之和。

求断掉所有边的贡献之和。

题解:

容易想到统计每棵树作为重心的次数。

对每个结点,记录其最大子树大小和次大子树大小 $sz_1,sz_2$。

对每个结点,断掉一条边相当于砍掉其一棵子树。那么有两种情况:1. 砍掉其最大子树中的一棵子树。2. 砍掉其最大子树外的一棵子树。

我们假设砍掉了 $k$ 个结点。

对于情况 2,必须满足 $\lfloor\frac{n-k}2\rfloor\geq sz1$,转化得 $n-2\times sz_1\geq k$。这个相当于求当前结点为根时,有多少子树不在当前结点的最大子树内,并且其子树大小小于等于 $n-2\times sz_1$。

对于情况 1,必须满足 $\lfloor\frac{n-k}2\rfloor \geq sz_2$ 且 $\lfloor\frac{n-k}2\rfloor\geq sz_1-k$,转化得 $2sz_1-n\leq k\leq n-2sz_2$。这个也相当于求当前结点为根时,有多少子树在当前结点的最大子树内,并且其子树大小在上述范围之间。

那么我们只需要支持查询原树的一棵子树中,有多少子树大小在某个范围之间即可。

考虑换根操作会影响哪些结点的子树大小。显然只会影响当前结点和换过去的结点的大小。

那么用一棵树状数组维护全局的所有子树大小出现次数,在换根时进行修改。

查询时,如果要查询结点上面那棵子树的信息,就用全局信息减去这个结点下面的子树的信息。

查询下面子树的信息,可以通过线段树合并实现。

时空复杂度 $O(n\log n)$。

被线性做法打爆了啊……技不如人……

Code:

#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=3e5+6;
typedef long long LL;
char buf[(int)4e7],*ss=buf;
inline int readint(){
    int d=0;
    while(!isdigit(*ss))++ss;
    while(isdigit(*ss))d=d*10+(*ss++^'0');
    return d;
}
struct edge{
    int to,nxt;
}e[N<<1];
struct info{
    int mx1,id1,mx2,id2;
}f[N];
int n,T,head[N],cnt,sz[N],dep[N],fa[N],ps[N];
LL ans=0;
int B[N];
inline void add(int i,int x){for(;i<=n;i+=i&-i)B[i]+=x;}
inline int ask(int i){int x=0;for(;i;i&=i-1)x+=B[i];return x;}
inline void update(info&x,int sz,int id){
    if(sz>x.mx1){
        x.mx2=x.mx1,x.id2=x.id1;
        x.mx1=sz,x.id1=id;
    }else
    if(sz>x.mx2){
        x.mx2=sz,x.id2=id;
    }
}
namespace sgt{
    int rt[N],d[N*28],nod;
    int ls[N*28],rs[N*28];
    inline void init(){
        memset(rt,0,sizeof rt),nod=0;
    }
    void insert(int&o,int l,int r,const int&pos,const int&x){
        if(!o)o=++nod,ls[o]=rs[o]=d[o]=0;
        d[o]+=x;
        if(l<r){
            const int mid=l+r>>1;
            if(pos<=mid)insert(ls[o],l,mid,pos,x);else
            insert(rs[o],mid+1,r,pos,x);
        }
    }
    int merge(int ld,int rd,int l,int r){
        if(!ld||!rd)return ld|rd;
        if(l==r){
            d[ld]+=d[rd];
            return ld;
        }
        const int mid=l+r>>1;
        ls[ld]=merge(ls[ld],ls[rd],l,mid);
        rs[ld]=merge(rs[ld],rs[rd],mid+1,r);
        d[ld]=d[ls[ld]]+d[rs[ld]];
        return ld;
    }
    int query(int o,int l,int r,const int&L,const int&R){
        if(!o)return 0;
        if(L<=l&&r<=R)return d[o];
        const int mid=l+r>>1;
        int ret=0;
        if(L<=mid)ret=query(ls[o],l,mid,L,R);
        if(mid<R)ret+=query(rs[o],mid+1,r,L,R);
        return ret;
    }
}
void dfs(int now){
    sz[now]=1;
    info&x=f[now];
    for(int i=head[now];i;i=e[i].nxt){
        const int to=e[i].to;
        if(!dep[to]){
            dep[to]=dep[now]+1,fa[to]=now;
            dfs(to);
            sz[now]+=sz[to];
            update(x,sz[to],to);
        }
    }
    if(now!=1)
    update(x,n-sz[now],fa[now]);
}
void dfs2(int now){
    for(int i=head[now];i;i=e[i].nxt)
    if(e[i].to!=fa[now]){
        sz[now]-=sz[e[i].to];
        add(sz[now],1);
        add(sz[e[i].to],-1);
        sz[e[i].to]+=sz[now];
        dfs2(e[i].to);
        sz[e[i].to]-=sz[now];
        add(sz[e[i].to],1);
        add(sz[now],-1);
        sz[now]+=sz[e[i].to];
    }
    LL ct=0;
    if(f[now].mx1==f[now].mx2){//最大两棵子树大小相等,任意删 
        int k=min(n-f[now].mx1*2,n-1);
        if(k>0)
        ct+=ask(k);
    }else{// 不相等,分类讨论 
        int k=min(n-f[now].mx1*2,n-1);
        if(k>0){// 删非重儿子 
            if(f[now].id1==fa[now]){
//                ct+=sgt::query(sgt::rt[now],1,n,1,k);
                for(int i=head[now];i;i=e[i].nxt)
                if(e[i].to!=fa[now])
                ct+=sgt::query(sgt::rt[e[i].to],1,n,1,k);
            }else
            ct+=(ask(k)-sgt::query(sgt::rt[f[now].id1],1,n,1,k));
        }
        int r=min(n-2*f[now].mx2,n-1),l=max(1,2*f[now].mx1-n);
        if(l<=r){//删重儿子
            if(f[now].id1!=fa[now])
            ct+=sgt::query(sgt::rt[f[now].id1],1,n,l,r);else{
                ct+=(ask(r)-ask(l-1));
                for(int i=head[now];i;i=e[i].nxt)
                if(e[i].to!=fa[now])
                ct-=sgt::query(sgt::rt[e[i].to],1,n,l,r);
//                ct-=sgt::query(sgt::rt[now],1,n,l,r);
            }
        } 
    }
    for(int i=head[now];i;i=e[i].nxt)
    if(e[i].to!=fa[now])
    sgt::rt[now]=sgt::merge(sgt::rt[now],sgt::rt[e[i].to],1,n);
    ans+=ct*now;
}
int main(){
    buf[fread(buf,1,(int)4e7-2,stdin)]='\n';fclose(stdin);
    for(T=readint();T--;){
        n=readint();
        memset(head,0,sizeof head);cnt=0;
        memset(B,0,sizeof B);
        memset(f,0,sizeof f);
        memset(fa,0,sizeof fa);
        memset(dep,0,sizeof dep);
        for(int i=1;i<n;++i){
            int u=readint(),v=readint();
            e[++cnt]=(edge){v,head[u]},head[u]=cnt;
            e[++cnt]=(edge){u,head[v]},head[v]=cnt;
        }
        dfs(dep[1]=1);
        sgt::init();
        for(int i=1;i<=n;++i)
        add(ps[i]=sz[i],1),sgt::insert(sgt::rt[i],1,n,sz[i],1);
        ans=0;
        dfs2(1);
        printf("%lld\n",ans);
    }
    return 0;
}