【Ynoi2012】惊惶的SCOI2016/【Codeforces 1172E】Nauuo and ODT

LCT。

题目大意:

给定一棵 $n$ 个点的树,每个点有颜色,颜色为 $1\sim n$。

每次操作会改变一个点的颜色。

对一条 $u\sim v$ 的路径,定义 $c(u,v)$ 为路径上的不同颜色个数。

要求第一次操作前和每次操作后,求出下面的式子的值。

题解:

我们对每种颜色分别考虑贡献。

要维护存在颜色 $x$ 的路径个数比较麻烦,考虑维护不存在颜色 $x$ 的路径个数。

我们把颜色不是 $x$ 的点定为白点,是 $x$ 的点定为黑点。那么不存在 $x$ 的路径条数,就是所有极大白色连通块的大小的平方和。

我们考虑维护极大白色连通块的平方和,要支持修改单点的颜色。

这个类似的 trick 在Qtree6中有使用。

我们考虑使用 Link Cut Tree 来维护这个东西,并且我们维护的是同色边连通块。也就是说,在同一个连通块里,除了根的颜色是黑色以外,其它结点的颜色均为白色(为了方便,可以给整棵树的根设置一个父亲结点)。

对于翻转一个颜色的操作,相当于 LCT 上的 Link/Cut 的操作。同时对每个结点,我们要维护它的子树大小,以及它的儿子结点的子树大小平方和。再用一个全局变量维护整个森林中每棵树的大小的平方和。

注意这里树的形态是不变的,我们的连边/删边只会在一个点和它的父亲之间,所以并不需要 makeroot 操作,否则反而会破坏这个结构。

然后考虑 Link 和 Cut 的操作。

对于 Link,设当前的点为 $x$。考虑答案的变化情况,减少了 $x$ 的所有儿子子树大小的平方和,以及 $x$ 的父亲所在连通块大小的平方;增加了连通后整棵树的大小的平方。

对于 Cut,设当前的点为 $x$。相对地,答案减少了整棵树的平方,增加了 $x$ 的所有儿子子树大小的平方和,以及原子树去掉 $x$ 子树后的大小的平方。

对于一个点的儿子结点的子树大小平方和,和维护子树信息类似的,我们只维护虚儿子的信息,然后在 access 的时候修改即可。

至于如何找一个连通块最上面的白点,容易发现,我们这样维护,一棵树的根一定是黑点。那么我们在连通块内部找一个结点进行 access,然后根结点所在 splay 的右儿子就是我们要找的连通块的根。

这样我们就可以维护不存在一个颜色 $x$ 的路径总数了,用 $n^2$ 减去它就是存在这个颜色的路径条数。

我们离线,对每个颜色分别处理即可。由于修改一个结点,只会影响两个颜色的信息,所以这里总的 Link/Cut 的次数是 $O(m)$。

由于每个颜色的贡献在一个段内是相等的,所以相当于在时间上面进行区间加。用差分可以轻松解决。

故总时间复杂度 $O((n+m)\log n)$。

Code:

#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
const int N=4e5+6;
int n,m,head[N],cnt,fa[N],dep[N],c[N];
LL ans[N],Num;
bool col[N];
vector<uLL>E[N];
namespace lct{
    int ch[N][2],fa[N],sz[N],sm[N],_v[N];LL szz[N];
    bool tag[N];
    inline bool ckr(int x,int p=1){return ch[fa[x]][p]==x;}
    inline bool isroot(int x){return!(ckr(x)||ckr(x,0));}
    inline LL calc(int x){return(LL)sm[x]*sm[x];}
    inline void update(int x){
        sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+1;
        sm[x]=sm[ch[x][0]]+sm[ch[x][1]]+_v[x]+1;
    }
    inline void flip(int x){swap(ch[x][0],ch[x][1]),tag[x]^=1;}
    inline void pushdown(int x){
        if(tag[x]){
            if(ch[x][0])flip(ch[x][0]);
            if(ch[x][1])flip(ch[x][1]);
            tag[x]=0;
        }
    }
    void rotate(int x){
        int y=fa[x],z=fa[y],k=ckr(x);
        if(!isroot(y))ch[z][ckr(y)]=x;
        fa[x]=z,fa[y]=x,fa[ch[x][!k]]=y;
        ch[y][k]=ch[x][!k],ch[x][!k]=y;update(y);
    }
    void splay(int x){
        static int sta[N],top;sta[top=1]=x;
        for(int y=x;!isroot(y);sta[++top]=y=fa[y]);
        while(top)pushdown(sta[top--]);
        for(;!isroot(x);rotate(x))
        if(!isroot(fa[x]))rotate((ckr(x)^ckr(fa[x]))?x:fa[x]);
        update(x);
    }
    void access(int x){
        for(int t=0;x;ch[x][1]=t,t=x,x=fa[x]){
            splay(x);
            _v[x]-=sm[t],_v[x]+=sm[ch[x][1]];
            szz[x]-=calc(t),szz[x]+=calc(ch[x][1]);
        }
    }
    void makeroot(int x){access(x),splay(x),flip(x);}
    int findroot(int x){access(x);for(splay(x);ch[x][0];x=ch[x][0]);splay(x);return x;}
    void link(int x,int y){
        splay(x);
        Num-=szz[x]+calc(ch[x][1]);
        access(y),splay(y);
        int z=findroot(y);splay(z);
        Num-=calc(ch[z][1]);
        splay(y),fa[x]=y;
        _v[y]+=sm[x],szz[y]+=calc(x),update(y);
        access(x),splay(z);
        Num+=calc(ch[z][1]);
    }
    void cut(int x,int y){
        access(x);
        Num+=szz[x];
        int z=findroot(x);splay(z);
        Num-=calc(ch[z][1]);
        splay(y),ch[y][1]=fa[x]=0,update(y);
        access(y),splay(z);
        Num+=calc(ch[z][1]);
    }
}
struct edge{
    int to,nxt;
}e[N<<1];
void dfs(int now,int pre){
    for(int i=head[now];i;i=e[i].nxt)
    if(e[i].to!=pre)fa[e[i].to]=now,dfs(e[i].to,now);
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;++i)cin>>c[i],E[c[i]].push_back(i);
    for(int i=1;i<n;++i){
        int u,v;
        cin>>u>>v;
        e[++cnt]=(edge){v,head[u]},head[u]=cnt;
        e[++cnt]=(edge){u,head[v]},head[v]=cnt;
    }
    dfs(1,0);
    for(int i=1;i<=m;++i){
        int u,x;
        cin>>u>>x;
        E[c[u]].push_back((uLL)i<<32|u);
        c[u]=x;
        E[x].push_back((uLL)i<<32|u);
    }
    fa[1]=n+1;
    for(int i=1;i<=n+1;++i)lct::sm[i]=lct::sz[i]=1;
    for(int i=1;i<=n;++i)
    lct::link(i,fa[i]);
    for(int i=1;i<=n;++i){
        LL dlt=0;
        for(uLL t:E[i]){
            int tm=t>>32,u=(int)t;
            if(col[u])lct::link(u,fa[u]);else lct::cut(u,fa[u]);
            col[u]^=1;
            ans[tm]+=(LL)n*n-Num-dlt;
            dlt=(LL)n*n-Num; 
        }
        for(int u:E[i])
        if(col[u])col[u]=0,lct::link(u,fa[u]);
    }
    for(int i=0;i<=m;++i)
    cout<<(ans[i]+=i?ans[i-1]:0)<<'\n';
    return 0;
}