【Ynoi2008】rplexq

根号分治,分块,莫队,复杂度平衡

题目大意:

给定一棵 $n$ 个结点的有根树,有 $m$ 个询问。

每次询问给出 $l,r,x$,求下列式子的值:

题解:

rplexq = range pair LCA equal x query

这题就是各种平衡复杂度。

考虑朴素的做法。对 $x$ 的每棵子树统计它里面编号在 $[l,r]$ 内的结点个数,然后计算两两乘积的和。可以通过总结点数平方减去每个儿子子树里结点平方的方式计算。

考虑对结点度数进行根号分治。

  • 对于儿子数小于等于 $\sqrt n$ 的 $x$,我们想办法对 $x$ 的每个询问都求出所有子树中编号在 $[l,r]$ 内的结点个数,然后计算答案。

    容易发现,我们相当于在做一个 $O(n)$ 个点,$O(m\sqrt n)$ 个询问的二维数点问题。

    为了平衡两边的复杂度,我们只需要使用分块前缀和的方式进行维护,这样单次修改 $O(\sqrt n)$,单次查询 $O(1)$。那么二维数点部分的总时间复杂度为 $O((n+m)\sqrt n)$。

    但是要存下所有的询问,需要 $O(m\sqrt n)$ 的空间复杂度。注意到对于一个 $x$ 的所有询问,它们所需要查询的子树都是一样的,而且 dfs 序区间连续且互不相交。因此我们数点时不具体记录每个询问,而是记录 $x$,处理时对这个 $x$ 计算答案。于是空间复杂度 $O(n+m)$。

  • 对于儿子数大于 $\sqrt n$ 的结点 $x$,这样的 $x$ 不超过 $\sqrt n$ 个。由于结点度数可能很大,因此我们不能逐个儿子进行查询。

    考虑枚举 $x$ 每个儿子 $v$,并遍历 $v$ 所在子树,将子树中的结点都标记上 $v$。然后对于每个询问,我们考虑求出不合法的方案,即每棵子树内结点的平方和。容易发现这是一个经典的问题——小 Z 的袜子,可以使用莫队算法求解。

    遍历结点的总复杂度是 $O(n\sqrt n)$ 没有问题,但是对于一个结点 $i$,莫队算法的复杂度是 $O(n_i\sqrt {m_i})$,而 $n_i$ 的总和可以达到 $n\sqrt n$,因此是错误的。

    我们希望 $n_i$ 的总和为 $n$,这样总的复杂度就为 $O(\sum n_i\sqrt{m_i})\leq O(\sum n_i\sqrt m)=O(n\sqrt m)$。

    考虑按 dfs 顺序处理询问,保证 $x$ 处理的时候,$x$ 子树中需要处理询问的结点均已处理完毕。

    在遍历结点时,我们分开记录之前的询问没遍历过的结点,和之前的询问遍历过的结点。

    对于之前询问没遍历的结点,我们将它放到序列上,对 $x$ 的询问,在这个序列上离散化后进行莫队。这样序列的总长度为 $O(n)$,则莫队的复杂度为 $O(n\sqrt m)$。

    对于之前询问遍历过的结点,我们对 $x$ 的每个儿子结点开一个数据结构,支持快速区间求和。由于我们在最开始进行了根号分治,因此存在之前询问遍历过的结点的儿子不超过 $\sqrt n$ 个,所以只需要对这些儿子开数据结构即可。

    由于每个询问要查询这 $\sqrt n$ 个儿子的信息,因此需要 $O(1)$ 的查询。想要做到 $O(1)$ 的查询则容易想到使用分块。朴素的做法对每个儿子开 $\sqrt n$ 大小的数组用于统计每个块内的结点个数,空间是可以承受的,但是还需要 $O(n)$ 大小的数组用于统计每个点的出现次数,这是无法承受的。注意到每个编号的点只有一个,因此我们可以直接开一个数组存这个点是否出现。空间复杂度 $O(n)$。

    这部分处理完以后,对分块数组进行前缀和统计,即可 $O(1)$ 查询块的区间和。

    在处理询问的时候,对于之前询问没遍历过的结点,它们的信息通过莫队算法得到。对于之前询问遍历过的结点,我们对每个儿子分别统计:块间的只需要 $O(1)$ 做块差分即可,共查询 $O(\sqrt n)$ 次,时间复杂度 $O(\sqrt n)$。边角的我们直接枚举,并对相应计数器加一,时间复杂度 $O(\sqrt n)$。然后再把这些儿子多出来的贡献加上去,时间复杂度 $O(\sqrt n)$。所以总时间复杂度 $O(m\sqrt n)$。

如果之前统计的是不合法方案,最后还需要一次二维数点来统计总方案。时间复杂度为 $O(m\log n)$。

综上,这个算法的时间复杂度为 $O(n\sqrt n+n\sqrt m+m\sqrt n)$,空间复杂度 $O(n+m)$。

Code:

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=2e5+5,_B=800;
inline int bel(int x){return(x-1)/_B+1;}
int n,m,rt,head[N],cnt,dfn[N],idfn[N],fa[N],idx,sons[N],out[N];
LL ans[N];
int presum[N],_pre[N];
struct edge{
    int to,nxt;
}e[N<<1];
int blo;
struct que{
    int l,r,id,_l,_r;
    inline bool operator<(const que&rhs)const{
        return(_l/blo!=rhs._l/blo)?_l<rhs._l:_r<rhs._r;
    }
}Q[N];
vector<que>q[N];
vector<int>_del[N],_add[N];
void dfs(int now){
    idfn[dfn[now]=++idx]=now;
    for(int i=head[now];i;i=e[i].nxt)if(!dfn[e[i].to])
    fa[e[i].to]=now,dfs(e[i].to),++sons[now];
    out[now]=idx;
}
int bL[N],bR[N],blocks;
namespace bk{
    int b[300],c[N];
    inline void add(int d){
        for(int i=bR[bel(d)];i>=d;--i)++c[i];
        for(int i=bel(d);i<=blocks;++i)++b[i];
    }
    inline int ask(int x){return c[x]+b[bel(x)-1];}
}
void solve_leq_sqrtn(){
    for(int i=1;i<=n;++i){
        bk::add(idfn[i]);
        for(int x:_add[i])
        for(const que&d:q[x]){
            const int x=bk::ask(d.r)-bk::ask(d.l-1)-_pre[d.id];
            ans[d.id]+=(LL)x*presum[d.id];
            presum[d.id]+=x,_pre[d.id]=0;
        }
        for(int x:_del[i])
        for(const que&d:q[x])
        _pre[d.id]=bk::ask(d.r)-bk::ask(d.l-1);
    }
}
int tot,_s[256][256],ct[N],BEL[N],_bel[N];
int ys[N];
bool vis_tm[N],g[N];
int A[N],len;
LL nans;
inline LL getans(int l,int r){
    static int Cnt[256];
    memset(Cnt,0,sizeof Cnt);
    const int iL=bel(l),iR=bel(r);
    LL res=0;
    for(int i=bL[iR];i<=r;++i)++Cnt[_bel[i]];
    for(int i=bL[iL];i<l;++i)--Cnt[_bel[i]];
    for(int i=1;i<=tot;++i){
        int x=Cnt[i]+_s[i][iR-1]-_s[i][iL-1];
        res+=(LL)ct[ys[i]]*x+(x-1LL)*x/2;
    }
    return res;
}
void solve_gt_sqrtn(){
    for(int _=n;_;--_){
        int _r=idfn[_];
        if(sons[_r]<=_B||!q[_r].size())continue;
        tot=len=0;
        memset(g,0,sizeof g);
        memset(ct,0,sizeof ct);
        memset(_bel,0,sizeof _bel);
        for(int i=head[_r];i;i=e[i].nxt)if(_<dfn[e[i].to]){
            int x=e[i].to,id=++tot;
            bool flag=0;ys[id]=e[i].to;
            for(int j=dfn[x];j<=out[x];++j)
            if(vis_tm[j])flag=1,++_s[id][bel(idfn[j])],_bel[idfn[j]]=id;
            else vis_tm[j]=g[idfn[j]]=1,BEL[idfn[j]]=x;
            tot-=!flag;
        }
        for(int i=1;i<=n;++i)if(g[i])A[++len]=i;
        int m=0;
        for(const que&id:q[_r]){
            Q[++m]=id;
            Q[m]._l=lower_bound(A+1,A+len+1,id.l)-A;
            Q[m]._r=upper_bound(A+1,A+len+1,id.r)-A-1;
        }
        for(int i=1;i<=len;++i)A[i]=BEL[A[i]];
        for(int i=1;i<=tot;++i)
        for(int j=2;j<=blocks;++j)_s[i][j]+=_s[i][j-1];
        blo=len/sqrt(m+1)+1;
        sort(Q+1,Q+m+1);
        nans=0;
        for(int i=1,l=1,r=0;i<=m;++i){
            while(r<Q[i]._r)nans+=ct[A[++r]]++;
            while(r>Q[i]._r)nans-=--ct[A[r--]];
            while(l>Q[i]._l)nans+=ct[A[--l]]++;
            while(l<Q[i]._l)nans-=--ct[A[l++]];
            ans[Q[i].id]=nans+getans(Q[i].l,Q[i].r);
        }
        for(int i=1;i<=tot;++i)memset(_s[i],0,sizeof*_s);
    }
}
namespace rplexq{
    struct OPTS{
        int t,l,r,id;
        inline bool operator<(const OPTS&rhs)const{return t<rhs.t;}
    }C[N*2];
    int ret[N],Tot;
    int B[N];
    inline void add(int i,int x){for(;i<=n;i+=i&-i)B[i]+=x;}
    inline int ask(int i){int x=0;for(;i;i&=i-1)x+=B[i];return x;}
    inline void insert(int l,int r,int id,int x){
        C[++Tot]=(OPTS){dfn[x],l,r,-id};
        C[++Tot]=(OPTS){out[x],l,r,id};
    }
    void work(){
        sort(C+1,C+Tot+1);
        for(int i=1,it=1;i<=n;++i){
            add(idfn[i],1);
            for(;it<=Tot&&C[it].t==i;++it)
            if(C[it].id<0)ret[-C[it].id]-=ask(C[it].r)-ask(C[it].l-1);
            else ret[C[it].id]+=ask(C[it].r)-ask(C[it].l-1);
        }
        for(int i=1;i<=m;++i)if(ret[i]){
            ans[i]=ret[i]*(ret[i]-1LL)/2-ans[i];
            if(presum[i])ans[i]+=ret[i];
        }
    }
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>m>>rt;
    blocks=bel(n);
    for(int i=1;i<=blocks;++i)bL[i]=bR[i-1]+1,bR[i]=i*_B;
    bR[blocks]=n;
    for(int i=1;i<n;++i){
        int u,v;
        cin>>u>>v;
        e[++cnt]=(edge){v,head[u]},head[u]=cnt;
        e[++cnt]=(edge){u,head[v]},head[v]=cnt;
    }
    dfs(rt);
    for(int i=1;i<=m;++i){
        int l,r,x;
        cin>>l>>r>>x;
        q[x].push_back((que){l,r,i});
        if(l<=x&&x<=r)presum[i]=1;
        if(sons[x]>_B)rplexq::insert(l,r,i,x);
    }
    for(int x=1;x<=n;++x)if(x!=rt&&sons[fa[x]]<=_B&&q[fa[x]].size()){
        _del[dfn[x]-1].push_back(fa[x]);
        _add[out[x]].push_back(fa[x]);
    }
    solve_leq_sqrtn();
    solve_gt_sqrtn();
    rplexq::work();
    for(int i=1;i<=m;++i)cout<<ans[i]<<'\n';
    return 0;
}