【CTSC2018】暴力写挂
点分治,虚树
题目大意:
给定两棵 $n$ 个点的树,均以 $1$ 为根,带边权。
令 $\mathrm{depth}_i(x)$ 表示 $x$ 在第 $i$ 棵树上到根路径的边权和,$\mathrm{LCA}_i(x,y)$ 表示 $x,y$ 在第 $i$ 棵树上到根路径的边权和。
要求找到 $x,y$ 使得 $\mathrm{depth}_1(x)+\mathrm{depth}_1(y)-\mathrm{depth}_1(\mathrm{LCA}_1(x,y))-\mathrm{depth}_2(\mathrm{LCA}_2(x,y))$ 最大。
求这个最大值。
题解:
先将这个式子乘 $2$,然后可以得到 $\mathrm{dist}(x,y)+\mathrm{depth}_1(x)+\mathrm{depth}_1(y)-2\cdot\mathrm{depth}_2(\mathrm{LCA}_2(x,y))$,其中 $\mathrm{dist}(x,y)$ 表示 $x$ 和 $y$ 在第一棵树上的距离。
然后考虑对第一棵树进行点分治。令 $\mathrm{dis}(i)$ 表示点 $i$ 到当前分治中心的距离,则点 $x,y$ 的贡献即为 $\mathrm{depth}_1(x)+\mathrm{depth}_1(y)+\mathrm{dis}(x)+\mathrm{dis}(y)-2\cdot\mathrm{depth}_2(\mathrm{LCA}_2(x,y) )$。
我们令 $w_x=\mathrm{depth}_1(x)+\mathrm{dis}(x)$,则相当于要找两个点使得 $w_x+w_y-2\cdot\mathrm{depth}_2(\mathrm{LCA}_2(x,y))$ 最大。
先不考虑 $x$ 和 $y$ 在同一个子树内的情况。考虑将当前连通块的点拉出来,在第二棵树上建虚树。然后令 $f_i$ 表示以 $i$ 为根的子树中的最大 $w$ 值,树形 dp 即可。建虚树的时间复杂度为 $O(V\log V)$,则总时间复杂度 $O(n\log^2 n)$。
但是 $x$ 和 $y$ 在同一棵子树内的情况是不能计算的。如何处理?
考虑对所有子连通块进行分治,每次合并两个部分,对两部分黑白染色,然后规定一个黑点和一个白点才能产生贡献,即 $f_{i,c}$ 表示以 $i$ 为根的子树中,颜色为 $c$ 的结点的最大 $w$ 值。
注意 $x=y$ 时的贡献没有在点分治中计算,需要额外判断。
由于点分治上加一层分治,是类似于边分治的过程的,因此时间复杂度仍为 $O(n\log^2 n)$。
这个复杂度较难通过,考虑对建虚树部分进行一些优化。将求 $\mathrm{LCA}$ 部分改用 ST 表做到 $O(1)$ 查询;事先对每个连通块按 dfs 序排序,然后用归并来合并,可以在常数上面进行优化。
似乎有办法能够做到 $O(n\log n)$ 的复杂度,我的代码没能优化掉 sort,是两个 $\log$ 的,不过实际表现还不错。
Code:
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int N=366670;
typedef long long LL;
typedef vector<int>VI;
struct edge{
int to,nxt,w;
}e[N<<1];
int n,head[N],cnt,rt,sz[N],mxd[N],all;
LL dis[N],ans=-1e18,_dis[N];
bool vis[N],col[N];
namespace T{
edge e[N<<1];
int head[N],cnt,st[21][N*2],idx,dfn[N],dep[N],idfn[2*N];
int out[N];
LL ds[N];
inline void addedge(int u,int v,int w){
e[++cnt]=(edge){v,head[u],w},head[u]=cnt;
e[++cnt]=(edge){u,head[v],w},head[v]=cnt;
}
inline int _min(int x,int y){return dep[x]<dep[y]?x:y;}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline int LCA(int x,int y){
x=dfn[x],y=dfn[y];
if(x>y)swap(x,y);
const int lg=__lg(y-x+1);
return _min(st[lg][x],st[lg][y-(1<<lg)+1]);
}
void dfs(int now){
st[0][dfn[now]=++idx]=now;
idfn[idx]=now;
for(int i=head[now];i;i=e[i].nxt)if(!dep[e[i].to]){
dep[e[i].to]=dep[now]+1;
ds[e[i].to]=ds[now]+e[i].w;
dfs(e[i].to);
st[0][++idx]=now;
}
out[now]=idx;
}
void init(){
dfs(dep[1]=1);
for(int i=1;i<21;++i)
for(int j=1;j+(1<<(i-1))<=idx;++j)
st[i][j]=_min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
}
}
VI id[N],tmp;int tot;
void merge(VI&s,VI&a,VI&b){
int ita=0,itb=0;
s.clear();
while(ita!=(int)a.size()&&itb!=(int)b.size())
if(T::cmp(a[ita],b[itb]))s.push_back(a[ita++]);
else s.push_back(b[itb++]);
while(ita!=(int)a.size())s.push_back(a[ita++]);
while(itb!=(int)b.size())s.push_back(b[itb++]);
}
namespace vtree{
VI sta,stb,stc;
int in[N];
LL dp[N][2];
int it;
void dfs(int now){
if(in[now]==1)dp[now][col[now]]=dis[now]+_dis[now];
++it;
while(it!=(int)stc.size()&&T::out[now]>T::dfn[stc[it]]){
int to=stc[it];
dfs(to);
ans=max(ans,dp[now][0]+dp[to][1]-2*T::ds[now]);
ans=max(ans,dp[now][1]+dp[to][0]-2*T::ds[now]);
dp[now][0]=max(dp[now][0],dp[to][0]);
dp[now][1]=max(dp[now][1],dp[to][1]);
}
}
void work(VI&V){
sta=V,stb.clear(),stc.clear();
for(int i:sta)in[i]=1;
for(int i=(int)V.size()-1;i>0;--i){
int x=T::LCA(V[i],V[i-1]);
if(!in[x])in[x]=2,stb.push_back(x);
}
sort(stb.begin(),stb.end(),T::cmp);
merge(stc,sta,stb);
for(int i:stc)dp[i][0]=dp[i][1]=-1e18;
dfs(stc[it=0]);
for(int i:stc)in[i]=0;
}
}
void solve(int l,int r,VI&vc=tmp){
if(l==r){
vc=id[l];
}else{
VI L,R;
const int mid=(l+r)>>1;
solve(l,mid,L),solve(mid+1,r,R);
for(int i:L)col[i]=0;
for(int i:R)col[i]=1;
merge(vc,L,R);
vtree::work(vc);
}
}
void dfs(int now,int pre){
for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=pre&&!vis[e[i].to])
dis[e[i].to]=dis[now]+e[i].w,dfs(e[i].to,now);
}
void getrt(int now,int pre){
mxd[now]=0,sz[now]=1;
for(int i=head[now];i;i=e[i].nxt)if(e[i].to!=pre&&!vis[e[i].to]){
getrt(e[i].to,now);
sz[now]+=sz[e[i].to];
mxd[now]=max(mxd[now],sz[e[i].to]);
}
mxd[now]=max(mxd[now],all-sz[now]);
if(!rt||mxd[rt]>mxd[now])rt=now;
}
void DFS(int now,int pre,VI&vc){
vc.push_back(now);
for(int i=head[now];i;i=e[i].nxt)if(!vis[e[i].to]&&e[i].to!=pre){
_dis[e[i].to]=_dis[now]+e[i].w;
DFS(e[i].to,now,vc);
}
}
void work(int now){
_dis[now]=0;
tot=0;
for(int i=head[now];i;i=e[i].nxt)
if(!vis[e[i].to]){
_dis[e[i].to]=e[i].w,DFS(e[i].to,now,id[++tot]);
sort(id[tot].begin(),id[tot].end(),T::cmp);
}
id[++tot].push_back(now);
solve(1,tot);
for(int i=1;i<=tot;++i)id[i].clear();
}
void solve(int now){
vis[now]=1;
work(now);
int sm=all;
for(int i=head[now];i;i=e[i].nxt)if(!vis[e[i].to]){
all=sz[now]>sz[e[i].to]?sz[e[i].to]:sm-sz[now];
rt=0,getrt(e[i].to,now);
solve(rt);
}
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<n;++i){
int u,v,w;
cin>>u>>v>>w;
e[++cnt]=(edge){v,head[u],w},head[u]=cnt;
e[++cnt]=(edge){u,head[v],w},head[v]=cnt;
}
for(int i=1;i<n;++i){
int u,v,w;
cin>>u>>v>>w;
T::addedge(u,v,w);
}
T::init();
dfs(1,0);
all=n,rt=0,getrt(1,0);
solve(rt);
for(int i=1;i<=n;++i)ans=max(ans,2*(dis[i]-T::ds[i]));
cout<<ans/2<<'\n';
return 0;
}