【洛谷P6177】Count on a tree II

bitset,简单树分块。

题目大意:

给定一棵树,每个点带点权,每次询问给出 $u,v$,求 $u$ 到 $v$ 的简单路径上的不同权值个数。强制在线。

题解:

这里需要用到一种简单的树分块技巧。简单来说就是树上撒点。

我们设一个阈值 $S$,并考虑在树上选择不超过 $\frac n S$ 个点作为关键点,满足每个关键点到离它最近的祖先关键点的距离不超过 $S$。

在树上随机选择 $\frac n S$ 个点,期望距离是正确的。

这里给出一个确定性算法,能够严格保证保证每个关键点到离它最近的祖先关键点的距离不超过 $S$。

具体地,我们每次选择一个深度最大的非关键点,然后若它的 $1\sim S$ 级祖先都不是关键点,则我们把它的 $S$ 级祖先标记为关键点。由标记过程可知距离不超过 $S$,并且每标记一个关键点,至少有 $S$ 个点不会被标记,所以关键点数量也是正确的。

接下来,我们考虑一条到根路径上的所有关键点 $x_1,x_2,x_3,\ldots,x_k$,我们用 bitset 维护它们两两间出现的颜色。处理的时候,可以用递推的方式,即 $b_{x_i,x_j}=b_{x_i,x_{j-1}}\ \mathrm{or}\ b_{x_{j-1},x_j}$。于是我们只需要用 $O(S)$ 的时间求出相邻两个关键点的 bitset,再处理出两两之间的即可。由于这样的点对最多只有 $\frac{n^2}{S^2}$ 对,所以预处理的复杂度为 $O(\frac{n^2}S+\frac{n^3}{w S^2})$。

接下来考虑求解答案。令 $t=\mathrm{LCA}(u,v)$。我们分别求出 $u,v$ 的祖先中离 $u,v$ 最近的关键点 $u_0,v_0$, 以及离 $t$ 最近且在 $t$ 子树中的关键点 $u_1,v_1$。那么一条路径被我们拆成了以下几个部分:$u\sim u_0$,$v\sim v_0$,$u_1\sim t$,$v_1\sim t$,$u_0\sim u_1$,$v_0\sim v_1$。对于前面 $4$ 种,我们直接暴力跳,时间复杂度 $O(S)$。而对于后面两种,我们已经预处理它们的 bitset,所以直接取并集即可。时间复杂度 $O(\frac n w)$。当然有可能出现 $u\sim t$,$v\sim t$ 之间没有关键点的情况,此时直接跳到 $t$ 即可。

最后求颜色个数,就调用 b.count() 即可,时间复杂度 $O(\frac n w)$。

时间复杂度 $O(\frac{n^2}S+\frac{n^3}{w S^2}+\frac{nm}{w}+mS)$。一般地,我们取 $S=\sqrt n$,得到时间复杂度 $O((n+m)\sqrt n+\frac{n^2+nm}w)$。

空间复杂度 $O(\frac{n^3}{S^2})$,一般地,我们取 $S=\sqrt n$,得到空间复杂度 $O(n^2)$,由于是 bitset,所以有较小的空间常数。

然后可以发现,当 $S$ 线性增长时,时间常数是线性增长的,而内存是平方级别下降的,所以我们可以通过调整块大小的方式来卡空间

在 $S$ 取 $1000$ 时仅用了 $11.39\texttt{MB}$,已经是大多数其他做法的约 $\frac 1 {10}$,再继续卡没有意义,而且会增加耗时。评测记录

Code:

#include<cstdio>
#include<bitset>
#include<vector>
#include<algorithm>
using namespace std;
const int N=40002;
bitset<N>bt[42][42],nw;
vector<int>vec;
int n,m,B,a[N],fa[N],dep[N],mxd[N],FF[N],sz[N],tp[N],son[N];
int id[N],cnt,head[N],tot,ans,sta[N],top,gg[N];
struct edge{
    int to,nxt;
}e[N<<1];
void dfs(int now){
    sz[now]=1,son[now]=0;
    mxd[now]=dep[now];
    for(int i=head[now];i;i=e[i].nxt)if(!dep[e[i].to]){
        dep[e[i].to]=dep[now]+1,fa[e[i].to]=now;
        dfs(e[i].to),sz[now]+=sz[e[i].to];
        if(mxd[e[i].to]>mxd[now])mxd[now]=mxd[e[i].to];
        if(!son[now]||sz[son[now]]<sz[e[i].to])son[now]=e[i].to;
    }
    if(mxd[now]-dep[now]>=1000)id[now]=++tot,mxd[now]=dep[now];
}
void dfs2(int now){
    for(int i=head[now];i;i=e[i].nxt)
    if(dep[e[i].to]>dep[now]){
        if(id[e[i].to]){
            int ip=id[sta[top]],in=id[e[i].to];
            for(int x=e[i].to;x!=sta[top];x=fa[x])bt[ip][in].set(a[x]);
            nw=bt[ip][in];
            for(int i=1;i<top;++i){
                bitset<N>&bs=bt[id[sta[i]]][in];bs=bt[id[sta[i]]][ip];bs|=nw;
            }
            FF[e[i].to]=sta[top],gg[e[i].to]=gg[sta[top]]+1;
            sta[++top]=e[i].to;
        }
        dfs2(e[i].to);
        if(id[e[i].to])--top;
    }
}
void dfs3(int now){
    if(son[now])tp[son[now]]=tp[now],dfs3(son[now]);
    for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=son[now]&&dep[e[i].to]>dep[now])
    dfs3(tp[e[i].to]=e[i].to);
}
inline int LCA(int x,int y){
    while(tp[x]!=tp[y])if(dep[tp[x]]>dep[tp[y]])x=fa[tp[x]];else y=fa[tp[y]];
    return dep[x]<dep[y]?x:y;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d",a+i),vec.push_back(a[i]);
    sort(vec.begin(),vec.end()),vec.erase(unique(vec.begin(),vec.end()),vec.end());
    for(int i=1;i<=n;++i)a[i]=lower_bound(vec.begin(),vec.end(),a[i])-vec.begin();
    for(int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        e[++cnt]=(edge){v,head[u]},head[u]=cnt;
        e[++cnt]=(edge){u,head[v]},head[v]=cnt;
    }
    dfs(dep[1]=1);
    if(!id[1])id[1]=++tot;
    sta[top=1]=gg[1]=1,dfs2(1),dfs3(1);
    while(m--){
        int u,v;
        scanf("%d%d",&u,&v),u^=ans,nw.reset();
        int l=LCA(u,v);
        while(u!=l&&!id[u])nw.set(a[u]),u=fa[u];
        while(v!=l&&!id[v])nw.set(a[v]),v=fa[v];
        if(u!=l){
            int pre=u;
            while(dep[FF[pre]]>=dep[l])pre=FF[pre];
            if(pre!=u)nw|=bt[id[pre]][id[u]];
            while(pre!=l)nw.set(a[pre]),pre=fa[pre];
        }
        if(v!=l){
            int pre=v;
            while(dep[FF[pre]]>=dep[l])pre=FF[pre];
            if(pre!=v)nw|=bt[id[pre]][id[v]];
            while(pre!=l)nw.set(a[pre]),pre=fa[pre];
        }
        nw.set(a[l]);
        printf("%d\n",ans=nw.count());
    }
    return 0;
}