【PKUWC2018】猎人杀
容斥,NTT
题目大意:
给定 $n$ 个猎人,每个猎人有仇恨度 $w_i$。每个猎人死后,会随机打死另外一个活着的猎人。
假设还活着的猎人的 $w_i$ 之和为 $W$,那么这个猎人被打死的概率为 $\frac{w_i}W$。
现在,你先按照上述规则随机打死一个猎人,问 $1$ 号猎人最后死的概率是多少。
题解:
概率计算时,分母会变化,非常麻烦。
我们考虑每次计算概率的时候,都把已经死了的猎人也考虑进去。如果打中了已经死的猎人或自己,那么当前猎人继续开枪,直到打到活着的为止。容易证明这样每轮每个猎人被打死的概率和原来是相同的。
考虑容斥。我们钦定集合 $S$ 里的人都在 $1$ 号之前被打死了,其余的人随便什么时候死。
令 $W$ 为 $\sum_{i=1}^n w_i$,$P=\sum_{i\in S} w_i$
考虑这个东西怎么计算。
由于 $\sum_{i=1}^n w_i\leq 10^5$,所以上面的式子中的 $P$ 一共只有 $10^5$ 种不同取值。
考虑对每个 $P$ 计算有多少不同的 $S$ 集合,其 $w_i$ 之和为 $P$。
简单的生成函数知识,对每个猎人构造生成函数 $1+x^{w_i}$ 即可。
由于带了容斥系数,多一个人就要变一次号,那么每个猎人的生成函数就是 $1-x^{w_i}$。
于是我们要求的就是 $\prod\limits_{i=2}^n (1-x^{w_i})$ 的各项系数。
由于 $w_i$ 总和不超过 $10^5$,多项式系数之和也不超过 $10^5$,所以对其进行分治 FFT 即可。
时间复杂度 $O((\sum w_i)\log n\log(\sum w_i))$。
Code:
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N=262144,md=998244353,g3=(md+1)/3;
typedef long long LL;
int w[N],n,lim,rev[N];
vector<int>s;
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
int ret=1;
for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
return ret;
}
void init(int n){
int l=-1;
for(lim=1;lim<n;lim<<=1)++l;
for(int i=1;i<lim;++i)
rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
}
void NTT(vector<int>&a,int f){
for(int i=1;i<lim;++i)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<lim;i<<=1){
const int gi=pow(f?3:g3,(md-1)/(i<<1));
for(int j=0;j<lim;j+=i<<1)
for(int k=0,g=1;k<i;++k,g=(LL)g*gi%md){
const int x=a[j+k],y=(LL)a[j+k+i]*g%md;
upd(a[j+k]+=y-md),upd(a[j+k+i]=x-y);
}
}
if(!f){
const int iv=pow(lim,md-2);
for(int i=0;i<lim;++i)a[i]=(LL)a[i]*iv%md;
}
}
void solve(int l,int r,vector<int>&s){
if(l==r){
s.resize(w[l]+1);
s[0]=1,s[w[l]]=md-1;
return;
}
vector<int>L,R;const int mid=l+r>>1;
solve(l,mid,L),solve(mid+1,r,R);
const int len=L.size()+R.size()-1;
init(len);
L.resize(lim),R.resize(lim),s.resize(lim);
NTT(L,1),NTT(R,1);
for(int i=0;i<lim;++i)s[i]=(LL)L[i]*R[i]%md;
NTT(s,0);
s.resize(len);
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<=n;++i)
cin>>w[i];
solve(2,n,s);
int ans=0;
for(int i=0;i<s.size();++i)
ans=(ans+(LL)pow(w[1]+i,md-2)*w[1]%md*s[i])%md;
cout<<ans<<endl;
return 0;
}