【PKUWC2018】猎人杀

容斥,NTT

题目大意:

给定 $n$ 个猎人,每个猎人有仇恨度 $w_i$。每个猎人死后,会随机打死另外一个活着的猎人。

假设还活着的猎人的 $w_i$ 之和为 $W$,那么这个猎人被打死的概率为 $\frac{w_i}W$。

现在,你先按照上述规则随机打死一个猎人,问 $1$ 号猎人最后死的概率是多少。

题解:

概率计算时,分母会变化,非常麻烦。

我们考虑每次计算概率的时候,都把已经死了的猎人也考虑进去。如果打中了已经死的猎人或自己,那么当前猎人继续开枪,直到打到活着的为止。容易证明这样每轮每个猎人被打死的概率和原来是相同的。

考虑容斥。我们钦定集合 $S$ 里的人都在 $1$ 号之前被打死了,其余的人随便什么时候死。

令 $W$ 为 $\sum_{i=1}^n w_i$,$P=\sum_{i\in S} w_i$

考虑这个东西怎么计算。

由于 $\sum_{i=1}^n w_i\leq 10^5$,所以上面的式子中的 $P$ 一共只有 $10^5$ 种不同取值。

考虑对每个 $P$ 计算有多少不同的 $S$ 集合,其 $w_i$ 之和为 $P$。

简单的生成函数知识,对每个猎人构造生成函数 $1+x^{w_i}$ 即可。

由于带了容斥系数,多一个人就要变一次号,那么每个猎人的生成函数就是 $1-x^{w_i}$。

于是我们要求的就是 $\prod\limits_{i=2}^n (1-x^{w_i})$ 的各项系数。

由于 $w_i$ 总和不超过 $10^5$,多项式系数之和也不超过 $10^5$,所以对其进行分治 FFT 即可。

时间复杂度 $O((\sum w_i)\log n\log(\sum w_i))$。

Code:

#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N=262144,md=998244353,g3=(md+1)/3;
typedef long long LL;
int w[N],n,lim,rev[N];
vector<int>s;
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
    return ret;
}
void init(int n){
    int l=-1;
    for(lim=1;lim<n;lim<<=1)++l;
    for(int i=1;i<lim;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
}
void NTT(vector<int>&a,int f){
    for(int i=1;i<lim;++i)
        if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<lim;i<<=1){
        const int gi=pow(f?3:g3,(md-1)/(i<<1));
        for(int j=0;j<lim;j+=i<<1)
            for(int k=0,g=1;k<i;++k,g=(LL)g*gi%md){
                const int x=a[j+k],y=(LL)a[j+k+i]*g%md;
                upd(a[j+k]+=y-md),upd(a[j+k+i]=x-y);
            }
    }
    if(!f){
        const int iv=pow(lim,md-2);
        for(int i=0;i<lim;++i)a[i]=(LL)a[i]*iv%md;
    }
}
void solve(int l,int r,vector<int>&s){
    if(l==r){
        s.resize(w[l]+1);
        s[0]=1,s[w[l]]=md-1;
        return;
    }
    vector<int>L,R;const int mid=l+r>>1;
    solve(l,mid,L),solve(mid+1,r,R);
    const int len=L.size()+R.size()-1;
    init(len);
    L.resize(lim),R.resize(lim),s.resize(lim);
    NTT(L,1),NTT(R,1);
    for(int i=0;i<lim;++i)s[i]=(LL)L[i]*R[i]%md;
    NTT(s,0);
    s.resize(len);
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;++i)
        cin>>w[i];
    solve(2,n,s);
    int ans=0;
    for(int i=0;i<s.size();++i)
        ans=(ans+(LL)pow(w[1]+i,md-2)*w[1]%md*s[i])%md;
    cout<<ans<<endl;
    return 0;
}