【Ynoi2009】rpdq

分块,虚树,卡常

题目大意:

给定一个 $n$ 个点的无根树,结点编号 $1\sim n$,带边权。

定义树上 $x,y$ 之间的边权和为 $\mathrm{dist}(x,y)$。

有 $m$ 组询问,每次给出 $l,r$,你需要求出:

题解:

rpdq = range pair distance query

事实证明序列分块的做法是可以通过的。

对原序列进行分块,则每个询问可以分为三个部分:左边边角块,中间部分,右边边角块。考虑计算这几个部分之间的贡献以及每部分内部的贡献,相加即为答案。

先来考虑整块对其他部分的贡献。对于一个块,我们可以在 $O(n)$ 的时间复杂度内求出这个块中的点与每个点的距离和。这是一个比较简单的问题,具体做法如下:

  • 将块内的点记为关键点。
  • 令 $s_i$ 表示 $i$ 子树中的关键点个数。$s_i$ 可以 $O(n)$ dfs 得到。
  • 令 $f_i$ 表示 $i$ 到所有关键点的距离和。$f_1$ 的答案就是所有关键点到根的距离和。
  • 考虑从 $f_{fa_i}$ 得到 $f_i$,只有 $i$ 和 $fa_i$ 之间的边经过的次数发生变化。有 $s_i$ 个关键点原来经过了这条边,现在不经过。其他的关键点原来不经过,现在经过。这样就可以 $O(n)$ 转移得到所有 $f$。容易使用 dfs 实现。

我们对 $f$ 进行前缀和,就可以 $O(1)$ 回答每个询问中,这个块对其他部分的贡献了。对询问离线后,可以对每个块都枚举询问,然后加上这个块的贡献。

因此这里的时间复杂度为 $O((n+m)\sqrt n)$。

上面处理了中间部分对所有部分的贡献,接下来只要考虑两个边角之间的贡献以及两个边角内部的贡献即可。

我们可以把两个边角块放在一起,那么就是要求 $O(\sqrt n)$ 个点的两两距离和。

考虑建出虚树后,就可以像上面计算整块的贡献一样,计算每个点到其他关键点的距离和了。单次计算的复杂度为 $O(\sqrt n)$。

那么只有一个问题了:建 $m$ 个点的虚树的时间复杂度为 $O(m\log m)$。

这里的 $\log$ 主要表现在对 dfs 序的排序上面,因此可以使用基数排序来去掉 $\log$。但是 $2\times 10^5$ 的范围下,需要以 $512$ 为基数才能保证复杂度和实际效率。而 $512$ 的大小可能卡不进 Cache 导致表现较差。因此考虑其他实现方式。

解决方案也很套路:对每个块里的点预先按 dfs 序排序,然后对两个块里的点,可以 $O(\sqrt n)$ 提取出它们按照 dfs 序排好序的序列,然后再 $O(\sqrt n)$ 归并即可。

剩下的就是一个经典的用栈来建虚树的过程。需要用到 $O(n\log n)-O(1)$ 的求最近公共祖先算法以保证复杂度为 $O(\sqrt n)$(不考虑 $O(n)-O(1)$,没研究过,也并不会在这里有更好的表现)。

于是我们做到了单次 $O(\sqrt n)$ 的时间复杂度。这部分一共会执行 $O(m)$ 次,因此时间复杂度为 $O(m\sqrt n)$。

注意这里计算的时候,一个点对的贡献被计算了 $2$ 次,因此最后要除以 $2$。所以我们至少要存储 $33$ 个二进制位,需要用到 $64$ 位整形。

综上,我们的算法的时间复杂度为 $O((n+m)\sqrt n)$,空间复杂度为 $O(n\log n+m)$。


然而由于该序列分块的做法常数较大(标算用了二次离线莫队,我目前尚未想到复杂度正确的做法),虽然时间复杂度达到了题目要求,但是直接莽完交上去几乎一定会得到 $0$ 分的好成绩。

所以感谢 skip2004 给我的代码进行常数优化。

下面介绍一些我能想到的,以及我能看得懂的常数优化。

首先这个 dfs 一看就非常慢。我们考虑直接在 dfs 序上求答案。这样就把一个递归变成了两个简单的循环语句。

整块的部分只剩下一些循环语句,而边角部分却有个虚树,还要用到大量 $64$ 位整数的运算,因此瓶颈在边角贡献的计算上。

由于边角贡献中,我们只关心贡献的和,所以可以不用求出每个点到其他结点的贡献,而是求虚树上每条边对答案的贡献。这样贡献不会算重,避免了 $64$ 位整数,而且可以省下一个循环。

这样处理以后,答案可以在建虚树的同时计算,不需要在建出虚树后再计算了。这样又节省了记录最终的 dfs 序列的开销。

  • 建虚树时,预先求出原 dfs 序列中相邻两点的 LCA,这样会变快。
  • 减少栈部分的寻址。
  • 优化询问 LCA 的常数。
  • 改进区间在区间内的判断(效果不明显)。
  • IO 优化(比较常规,效果可能不明显)。
  • 调整块大小(重要)。

加上各种合理的优化后,该算法是可以通过此题的。

Code:

//rpdq  =?= range pair distance query
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5,_B=210,srz=N/_B+2;
typedef unsigned uint;
struct istream {
    static const int size = 1 << 24;
    char buf[size], *vin;
    inline istream() {
        fread(buf,1,size,stdin);
        vin = buf - 1;
    }
    inline istream& operator >> (int & x) {
        for(;isspace(*++vin););
        for(x = *vin & 15;isdigit(*++vin);) x = x * 10 + (*vin & 15);
        return * this;
    }
} Cin;
struct ostream {
    static const int size = 1 << 22;
    char buf[size], *vout;
    unsigned map[10000];
    inline ostream() {
        for(int i = 0;i < 10000;++i) {
            map[i] = i % 10 + 48 << 24 | i / 10 % 10 + 48 << 16 | i / 100 % 10 + 48 << 8 | i / 1000 + 48;
        }
        vout = buf + size;
    }
    inline ~ ostream()
    { fwrite(vout,1,buf + size - vout,stdout); }
    inline ostream& operator << (uint x) {
        for(;x >= 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
        do *--vout = x % 10 + 48; while(x /= 10);
        return * this;
    }
    inline ostream& operator << (char x) {
        *--vout = x;
        return * this;
    }
} Cout;
#define bel(x)int(uint(x-1)/_B+1)
int n,m,head[N],cnt,dfn[N],idfn[N],idx,fa[N],_sz[N],dep[N],LR[N],size[N];
uint dis[N];
uint ans[N],F[N],_d[N],G[N];
int _l[N],_r[N];
int bL[N/_B+2],bR[N/_B+2],blocks,_s[N];
struct edge{
    int to,nxt,w;
}e[N<<1];
struct _cmp{
    inline bool operator()(const int&a,const int&b){return dfn[a]<dfn[b];}
};
namespace LCA{
    int st[19][N];
    inline int _min(int x,int y){return dfn[x]<dfn[y]?x:y;}
    void init(){
        for(int i=1;i<19;++i)
        for(int j=1;j+(1<<i)-1<n;++j)
        st[i][j]=_min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
    }
    inline int ask(int x,int y){
        const int lg = __lg(dfn[y] - dfn[x]);
        return _min(st[lg][dfn[x]], st[lg][dfn[y] - (1 << lg)]);
    }
}
void dfs(int now){
    idfn[dfn[now]=++idx]=now;size[now]=1;
    for(int i=head[now];i;i=e[i].nxt)if(!dfn[e[i].to]){
        dep[e[i].to]=dep[now]+1;
        dis[e[i].to]=dis[now]+e[i].w;
        LCA::st[0][idx]=now;
        dfs(e[i].to);
        fa[dfn[e[i].to]]=dfn[now],
        _d[dfn[e[i].to]]=e[i].w;
        size[now]+=size[e[i].to];
    }
}
inline int is_anc(int x, int y) {
    return uint(dfn[y] - dfn[x]) < (uint)size[x];
}
struct range {
    int l, r;
    range(int L, int R) : l(L), r(R - L) {}
    inline bool in(int x) const {
        return uint(x - l) <= (uint)r;
    }
};
uint _query(int l0,int r0,int l1,int r1){
    int l=bL[bel(l0)],r=bR[bel(l0)];
    static int A[srz*2],B[srz*2],C[srz*2],sta[srz*2],lca[srz*2];int tA=0,tB=0,tC=0,top=1;
    sta[1]=1;
    const range wwj(l0, r0), dak(l1, r1);
    for(int i=l;i<=r;++i)if(wwj.in(_s[i]))A[++tA]=_s[i];
    l=bL[bel(l1)],r=bR[bel(l1)];
    if(l1<=r1)
    for(int i=l;i<=r;++i)if(dak.in(_s[i]))B[++tB]=_s[i];

    merge(A+1,A+tA+1,B+1,B+tB+1,C+1,_cmp());
    tC=tA+tB;


    for(int i=2;i<=tC;++i)lca[i]=LCA::ask(C[i-1],C[i]);
    uint a=0,b=0;
    for(int i=1;i<=tC;++i){
        if(is_anc(sta[top],C[i]))sta[++top]=C[i];
        else{
            int t=lca[i];
            while(top>1&&dfn[sta[top-1]]>=dfn[t]){
                b+=_sz[top-1]*_sz[top]*dis[sta[top-1]];
                _sz[top-1]+=_sz[top];
                --top;
            }
            sta[top]=t;
            sta[++top]=C[i];
        }
        a+=dis[C[i]];
        _sz[top]=1;
    }
    while(top>1){
        b+=_sz[top-1]*_sz[top]*dis[sta[top-1]];
        _sz[top-1]+=_sz[top];
        --top;
    }
    return a*(tC-1)-b*2;
}
int main(){
    Cin>>n>>m;
    for(int i=1;i<n;++i){
        int u,v,w;
        Cin>>u>>v>>w;
        e[++cnt]=(edge){v,head[u],w},head[u]=cnt;
        e[++cnt]=(edge){u,head[v],w},head[v]=cnt;
    }
    dfs(dep[1]=1);
    LCA::init();
    blocks=bel(n);
    for(int i=1;i<=blocks;++i)bL[i]=bR[i-1]+1,bR[i]=i*_B;
    bR[blocks]=n;
    for(int i=1;i<=n;++i)_s[i]=i;
    for(int i=1;i<=blocks;++i)
    sort(_s+bL[i],_s+bR[i]+1,_cmp());
    for(int i=1;i<=m;++i){
        int&l=_l[i],&r=_r[i];
        Cin>>l>>r;
        LR[i]=bel(r);
        if(bel(l)==bel(r))ans[i]=_query(l,r,1,0);
        else ans[i]=_query(l,bR[bel(l)],bL[bel(r)],r);
    }

    for(int id=1;id<=blocks;++id){
        F[1]=0;
        const int L=bL[id],R=bR[id];
        uint k=_query(L,R,1,0);

        memset(_sz,0,sizeof _sz);
        for(int i=L;i<=R;++i)F[1]+=dis[i],_sz[dfn[i]]=1;

        for(int i=n;i>1;--i)_sz[fa[i]]+=_sz[i];
        for(int i=2;i<=n;++i)F[i]=F[fa[i]]+_d[i]*(R-L+1-2*_sz[i]);
        G[1]=F[1];
        for(int i=2;i<=n;++i)G[i]=F[dfn[i]]+G[i-1];

        static int _g[N/_B+2];
        for(int i=1;i<=blocks;++i)_g[i]=G[bL[i]-1];
        for(int i=1;i<=m;++i)if(_l[i]<L&&R<_r[i])
        ans[i]+=k+G[L-1]-G[_l[i]-1]+G[_r[i]]-_g[LR[i]];
    }
    for(int i=m;i>=1;--i)Cout<<'\n'<<(uint)ans[i];
    return 0;
}