【AtCoder Regular Contest 087 F】Squirrel Migration

动态规划,容斥。

题目大意:

给定一棵 $n$ 个结点的树。

对于一个排列 $p_1,p_2,p_3,\ldots,p_n$,它的贡献为 $\sum\limits_{i=1}^n\mathrm{dist}(i,p_i)$。其中 $\mathrm{dist}(i,j)$ 表示树上 $i$ 和 $j$ 的距离。

求贡献最大的排列的个数

题解:

首先考虑怎么才能使总贡献最大。

考虑边的贡献。我们要最大化每条边的贡献。

对于一条边 $(u,v)$,它的最多的经过次数是 $\min(sz_u,sz_v)$,其中 $sz_u$ 和 $sz_v$ 分别是以 $v$ 为根时 $u$ 子树大小,和以 $u$ 为根时 $v$ 子树大小。

考虑能否达到这个界。我们求出这棵树的重心,如果有两个重心,则分别在两边来往即可。如果只有一个重心,则每次都跨过重心即可。

接着考虑方案数。对于有两个重心的方案,它们只能在两边来往,每次到的点可以任意选择,所以总方案数为 $((\frac n 2)!)^2$。

对于只有一个重心的方案,我们对每棵子树分别考虑。对于在子树 $i$ 里的点,它对应的 $p_i$ 就不能在子树 $i$ 里了。

我们考虑容斥,即计算至少有 $k$ 个点对应的点仍在自己子树里的方案数。

令 $f_{i,j}$ 表示当前考虑到子树 $i$,已经有 $j$ 个点对应的点在自己子树内了。

转移如下:

这个式子中,组合数表示任选 $k$ 个移动到自己子树内,排列表示任选 $k$ 个作为目标结点,顺序可以调换。

最后的答案为 $\sum\limits _{i=0}^n (-1)^i\cdot f_{last,i}\cdot (n-i)!$。

这个 dp 是类似树上的背包问题,可以通过子树大小进行优化,其时间复杂度为 $O(n^2)$。

Code:

#include<cstdio>
#include<cstring>
typedef long long LL;
const int N=5005,md=1e9+7;
int n,head[N],cnt,sz[N],mxd[N],rt,fac[N],iv[N],root;
int f[2][N];
bool tag;
struct edge{
    int to,nxt;
}e[N<<1];
inline int C(int n,int m){return(LL)fac[n]*iv[m]%md*iv[n-m]%md;}
inline int P(int n,int m){return(LL)fac[n]*iv[n-m]%md;}
inline int pow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
    return ret;
}
void getrt(int now,int pre){
    sz[now]=1,mxd[now]=0;
    for(int i=head[now];i;i=e[i].nxt)
        if(e[i].to!=pre){
            getrt(e[i].to,now);
            sz[now]+=sz[e[i].to];
            if(mxd[now]<sz[e[i].to])mxd[now]=sz[e[i].to];
        }
    if(mxd[now]<n-sz[now])mxd[now]=n-sz[now];
    if(!rt||mxd[now]<mxd[rt])rt=now,tag=0;
    else if(mxd[now]==mxd[rt])tag=1;
}
int main(){
    scanf("%d",&n);
    for(int i=*fac=1;i<=n;++i)fac[i]=(LL)fac[i-1]*i%md;
    iv[n]=pow(fac[n],md-2);
    for(int i=n-1;~i;--i)iv[i]=(i+1LL)*iv[i+1]%md;
    for(int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        e[++cnt]=(edge){v,head[u]},head[u]=cnt;
        e[++cnt]=(edge){u,head[v]},head[v]=cnt;
    }
    rt=0;
    getrt(1,0);
    if(tag){
        printf("%lld\n",(LL)fac[n/2]*fac[n/2]%md);
        return 0;
    }
    root=rt,rt=0;
    getrt(root,0);
    int cur=0,lim=0;
    f[0][0]=1;
    for(int i=head[rt];i;i=e[i].nxt){
        int to=e[i].to;
        cur^=1;
        memset(f[cur],0,sizeof*f);
        lim+=sz[to];
        for(int j=lim;~j;--j)
            for(int k=0;k<=j&&k<=sz[to];++k)
                f[cur][j]=(f[cur][j]+(LL)f[cur^1][j-k]*C(sz[to],k)%md*P(sz[to],k))%md;
    }
    int ans=0;
    for(int i=0,x=1;i<=n;++i,x=md-x)
        ans=(ans+(LL)x*f[cur][i]%md*fac[n-i])%md;
    printf("%d\n",ans);
    return 0;
}