【AtCoder Grand Contest 019 E】Shuffle and Swap
概率 DP,多项式快速幂。
题目大意:
给定两个 01
串 $A,B$,其中各有 $k$ 个 1
。令 $a_{1..k},b_{1..k}$ 分别表示 $A,B$ 中 1
的出现位置。
现在将 $a$ 和 $b$ 随机打乱,然后对 $i=1\sim n$,依次交换 $A_{a_i}$ 和 $A_{b_i}$。求最终 $A$ 和 $B$ 相等的概率(方案数)。
题解:
对于一个 $i$,若 $A_i=B_i=0$,则其没有任何用,直接忽略。
令 $x$ 表示 $A_i=B_i=1$ 的位置 $i$ 个数,$y$ 表示 $A_i=1,B_i=0$ 的位置 $i$ 的个数(和 $A_i=0,B_i=1$ 的个数相等)。
我们把交换两个位置,看做两个位置的有向连边。则我们将会连出若干环和恰好 $y$ 条链。
对于一条链,我们的目标是将$B_i=0$ 的位置的 $A_i$ 变成 $0$,而这个 $0$ 必须从 $B_i=1,A_i=0$ 的地方按次序交换过来(因为中间其他的 $A_i$ 都是 $1$)。
所以对于 $A_i=0,B_i=1$ 的点,一定作为链的开头,而 $A_i=1,B_i=0$ 的点,一定作为链的结尾。
考虑设 $f_{i,j}$ 表示用了 $j$ 个 $A=B=1$ 的点,连了 $i$ 条链,最终交换正确的概率。其中点是无标号的。
我们枚举新链需要的 $A=B=1$ 的点的个数 $k$,则中间有 $k+1$ 条连边,有 $(k+1)!$ 种不同方式。而只有 $1$ 种方式最后能完成交换。因此可得到转移方程:
直接做是 $O(n^3)$ 的。
设 $F(x)=\sum_{i}\frac{1}{(i+1)!}$,则我们相当于在求 $F^y(x)$ 的各项系数。使用 FFT 来做多项式幂函数即可。用快速幂可以较为简单地实现,时间复杂度 $O(n\log^2 n)$。
最后的答案可以表示为:
其中 $x!$ 表示任意给 $A=B=1$ 的点安排顺序,$y!$ 表示链首和链尾的不同对应方式,$(x+y)!$ 表示所有边的排列方式,乘上概率就是方案数。
Code:
#include<cstdio>
#include<algorithm>
typedef long long LL;
const int N=32768,md=998244353;
int A,B,a[N],W[N],rev[N],lim,b[N],fac[N],iv[N],ans;
char s[N],t[N];
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
int ret=1;
for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
return ret;
}
void init(int n){
int l=-1;
for(lim=1;lim<n;lim<<=1)++l;
for(int i=1;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
}
void FFT(int*b,const int&f){
static unsigned long long a[N];
for(int i=0;i<lim;++i)a[i]=b[rev[i]];
for(int i=1;i<lim;i<<=1)
for(int j=0;j<lim;j+=i<<1)
for(int k=0;k<i;++k){
const unsigned long long x=a[j+k],y=a[j+k+i]*W[i+k]%md;
a[j+k]+=y,a[j+k+i]=x+md-y;
}
for(int i=0;i<lim;++i)b[i]=a[i]%md;
if(!f){
const int iv=pow(lim,md-2);
for(int i=0;i<lim;++i)b[i]=(LL)b[i]*iv%md;
std::reverse(b+1,b+lim);
}
}
void pow(int*F,int*A,int b,int n){
FFT(F,1),FFT(A,1);
for(;b;b>>=1){
if(b&1){
for(int i=0;i<lim;++i)F[i]=(LL)F[i]*A[i]%md;
FFT(F,0);
for(int i=n;i<lim;++i)F[i]=0;
FFT(F,1);
}
for(int i=0;i<lim;++i)A[i]=(LL)A[i]*A[i]%md;
FFT(A,0);
for(int i=n;i<lim;++i)A[i]=0;
FFT(A,1);
}
FFT(F,0);
}
int main(){
scanf("%s%s",s,t);
for(int i=0;s[i];++i)
if(s[i]=='1'&&t[i]=='1')++A;else
if(s[i]=='1'&&t[i]=='0')++B;
for(int i=1;i<N;i<<=1){
const int w=pow(3,(md-1)/(i<<1));
W[i]=1;
for(int j=1;j<i;++j)W[i+j]=(LL)W[i+j-1]*w%md;
}
for(int i=*fac=1;i<=A+B+1;++i)fac[i]=(LL)fac[i-1]*i%md;
iv[A+B+1]=pow(fac[A+B+1],md-2);
for(int i=A+B;~i;--i)iv[i]=(i+1LL)*iv[i+1]%md;
*a=1;
for(int i=0;i<=A;++i)b[i]=iv[i+1];
init(A+1<<1);
pow(a,b,B,A+1);
for(int i=0;i<=A;++i)upd(ans+=a[i]-md);
ans=(LL)ans*fac[A]%md*fac[B]%md*fac[A+B]%md;
printf("%d\n",ans);
return 0;
}