【AtCoder Grand Contest 019 E】Shuffle and Swap

概率 DP,多项式快速幂。

题目大意:

给定两个 01 串 $A,B$,其中各有 $k$ 个 1。令 $a_{1..k},b_{1..k}$ 分别表示 $A,B$ 中 1 的出现位置。

现在将 $a$ 和 $b$ 随机打乱,然后对 $i=1\sim n$,依次交换 $A_{a_i}$ 和 $A_{b_i}$。求最终 $A$ 和 $B$ 相等的概率(方案数)。

题解:

对于一个 $i$,若 $A_i=B_i=0$,则其没有任何用,直接忽略。

令 $x$ 表示 $A_i=B_i=1$ 的位置 $i$ 个数,$y$ 表示 $A_i=1,B_i=0$ 的位置 $i$ 的个数(和 $A_i=0,B_i=1$ 的个数相等)。

我们把交换两个位置,看做两个位置的有向连边。则我们将会连出若干环和恰好 $y$ 条链。

对于一条链,我们的目标是将$B_i=0$ 的位置的 $A_i$ 变成 $0$,而这个 $0$ 必须从 $B_i=1,A_i=0$ 的地方按次序交换过来(因为中间其他的 $A_i$ 都是 $1$)。

所以对于 $A_i=0,B_i=1$ 的点,一定作为链的开头,而 $A_i=1,B_i=0$ 的点,一定作为链的结尾。

考虑设 $f_{i,j}$ 表示用了 $j$ 个 $A=B=1$ 的点,连了 $i$ 条链,最终交换正确的概率。其中点是无标号的。

我们枚举新链需要的 $A=B=1$ 的点的个数 $k$,则中间有 $k+1$ 条连边,有 $(k+1)!$ 种不同方式。而只有 $1$ 种方式最后能完成交换。因此可得到转移方程:

直接做是 $O(n^3)$ 的。

设 $F(x)=\sum_{i}\frac{1}{(i+1)!}$,则我们相当于在求 $F^y(x)$ 的各项系数。使用 FFT 来做多项式幂函数即可。用快速幂可以较为简单地实现,时间复杂度 $O(n\log^2 n)$。

最后的答案可以表示为:

其中 $x!$ 表示任意给 $A=B=1$ 的点安排顺序,$y!$ 表示链首和链尾的不同对应方式,$(x+y)!$ 表示所有边的排列方式,乘上概率就是方案数。

Code:

#include<cstdio>
#include<algorithm>
typedef long long LL;
const int N=32768,md=998244353;
int A,B,a[N],W[N],rev[N],lim,b[N],fac[N],iv[N],ans;
char s[N],t[N];
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
    return ret;
}
void init(int n){
    int l=-1;
    for(lim=1;lim<n;lim<<=1)++l;
    for(int i=1;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
}
void FFT(int*b,const int&f){
    static unsigned long long a[N];
    for(int i=0;i<lim;++i)a[i]=b[rev[i]];
    for(int i=1;i<lim;i<<=1)
    for(int j=0;j<lim;j+=i<<1)
    for(int k=0;k<i;++k){
        const unsigned long long x=a[j+k],y=a[j+k+i]*W[i+k]%md;
        a[j+k]+=y,a[j+k+i]=x+md-y;
    }
    for(int i=0;i<lim;++i)b[i]=a[i]%md;
    if(!f){
        const int iv=pow(lim,md-2);
        for(int i=0;i<lim;++i)b[i]=(LL)b[i]*iv%md;
        std::reverse(b+1,b+lim);
    }
}
void pow(int*F,int*A,int b,int n){
    FFT(F,1),FFT(A,1);
    for(;b;b>>=1){
        if(b&1){
            for(int i=0;i<lim;++i)F[i]=(LL)F[i]*A[i]%md;
            FFT(F,0);
            for(int i=n;i<lim;++i)F[i]=0;
            FFT(F,1);
        }
        for(int i=0;i<lim;++i)A[i]=(LL)A[i]*A[i]%md;
        FFT(A,0);
        for(int i=n;i<lim;++i)A[i]=0;
        FFT(A,1);
    }
    FFT(F,0);
}
int main(){
    scanf("%s%s",s,t);
    for(int i=0;s[i];++i)
    if(s[i]=='1'&&t[i]=='1')++A;else
    if(s[i]=='1'&&t[i]=='0')++B;
    for(int i=1;i<N;i<<=1){
        const int w=pow(3,(md-1)/(i<<1));
        W[i]=1;
        for(int j=1;j<i;++j)W[i+j]=(LL)W[i+j-1]*w%md;
    }
    for(int i=*fac=1;i<=A+B+1;++i)fac[i]=(LL)fac[i-1]*i%md;
    iv[A+B+1]=pow(fac[A+B+1],md-2);
    for(int i=A+B;~i;--i)iv[i]=(i+1LL)*iv[i+1]%md;
    *a=1;
    for(int i=0;i<=A;++i)b[i]=iv[i+1];
    init(A+1<<1);
    pow(a,b,B,A+1);
    for(int i=0;i<=A;++i)upd(ans+=a[i]-md);
    ans=(LL)ans*fac[A]%md*fac[B]%md*fac[A+B]%md;
    printf("%d\n",ans);
    return 0;
}