【AtCoder Regular Contest 086 E】Smuggling Marbles

动态规划,启发式合并。

题目大意:

给定一棵以 $0$ 为根的树,初始每个结点可能有 $0$ 或 $1$ 个石子。

然后要进行以下操作无数次:

  • 根结点如果有石子,放到自己的口袋里。
  • 将每个结点的石子移动到父亲结点处。
  • 检查每个结点,如果有大于 $1$ 个石子,则将这些石子全部拿掉(不放入口袋)。

问所有不同的初始方案最后能放入口袋的石子的总和。

题解:

可以发现不同层中的石子是不会互相影响的。

所以我们考虑按层 dp。

定义 $f_{i,0/1/2}$ 表示当前结点所在子树中,深度为 $i$ 的结点,到这层以后有 $0$ 个、$1$ 个、超过 $1$ 个石子的方案数。

每个结点会保存的状态个数为它子树中深度最大的结点的深度减去它本身的深度。

考虑合并两个子树的状态。我们采用启发式合并,将浅的合并到深的上去即可。

时间复杂度证明类似长链剖分,是 $O(n)$ 的。

Code:

#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;
#define fi first
#define se second
const int N=2e5+5,md=1e9+7;
int id[N],n,fa[N],head[N],cnt,dep[N],ans,_2[N],tot[N];
struct edge{
    int to,nxt;
}e[N];
struct info{
    int _0,_1,_2;
}q;
struct data{
    int ed;
    vector<info>dp;
}F[N];
int mx;
void merge(int&x,int&y){
    if(F[x].ed<F[y].ed)x^=y^=x^=y;
    vector<info>&dpx=F[x].dp,&dpy=F[y].dp;
    for(int _y=(int)dpy.size()-1,_x=(int)dpx.size()-1;~_y;--_y,--_x)
        q=(info){(LL)dpx[_x]._0*dpy[_y]._0%md,((LL)dpx[_x]._0*dpy[_y]._1+(LL)dpx[_x]._1*dpy[_y]._0)%md,((LL)(dpx[_x]._1+dpx[_x]._2)*(dpy[_y]._1+dpy[_y]._2)+(LL)dpx[_x]._0*dpy[_y]._2+(LL)dpx[_x]._2*dpy[_y]._0)%md},dpx[_x]=q;
    mx=max(mx,(int)dpy.size());
    dpy.clear();
}
void dfs(int now){
    F[now].ed=dep[now];
    ++tot[dep[now]];
    for(int i=head[now];i;i=e[i].nxt)
        dep[e[i].to]=dep[now]+1,dfs(e[i].to);
    mx=0;
    for(int i=head[now];i;i=e[i].nxt)
        merge(id[now],id[e[i].to]);
    vector<info>&dpp=F[id[now]].dp;
    for(int it=(int)dpp.size()-1;mx--;--it)
        (dpp[it]._0+=dpp[it]._2)%=md,dpp[it]._2=0;
    dpp.push_back((info){1,1,0});
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;++i){
        cin>>fa[i];
        e[++cnt]=(edge){i,head[fa[i]]},head[fa[i]]=cnt;
        id[i]=i;
    }
    dfs(0);
    for(int i=*_2=1;i<=n+1;++i)_2[i]=_2[i-1]*2%md;
    vector<info>&dp=F[*id].dp;
    reverse(dp.begin(),dp.end());
    for(int i=0;i<dp.size();++i)
        ans=(ans+(LL)dp[i]._1*_2[n+1-tot[i]])%md;
    cout<<ans<<'\n';
    return 0;
}