【Codeforces 773F】Test Data Generation

倍增,FFT。

题目大意:

给定 $N,M$,问满足以下条件的集合个数。

  • 集合中元素互不相同。
  • 集合中的元素个数不超过 $n$,集合不为空。
  • 集合中的元素的范围为 $[1,M]$,均为整数。
  • 设 $g$ 为集合中所有元素的最大公因数,$n$ 为集合中颜色个数,$m$ 为元素的最大值。要求 $\frac m g-n$ 的奇偶性和 $\frac m g$ 的奇偶性,$\frac m g-n$ 的奇偶性和 $m-n$ 的奇偶性都不同。

答案对给定的模数取模。

题解:

因为 $\frac m g-n$ 和 $\frac m g$ 的奇偶性不同,所以 $n$ 必是奇数。

又因为 $\frac m g-n$ 的奇偶性和 $m-n$ 的奇偶性不同,所以 $\frac m g$ 的奇偶性和 $m$ 的奇偶性不同,那么必然是 $\frac m g$ 为奇数,$m$ 为偶数。

我们考虑枚举 $g$ 中含有的因子 $2$ 的个数 $k$,那么相当于我们要求有多少个长度为奇数递增序列,它的元素范围为 $[1,\lfloor\frac M{2^k}\rfloor]$,且最大的元素为奇数

定义 $f(a,0/1)$ 表示元素大小为 $[1,a]$,最后一位为偶数/奇数的方案数的生成函数(空集的方案数为 $0$)。

考虑从 $f(a,p)$ 转移到 $f(a+1,p)$。

没有 $a+1$ 的情况就是 $f(a,p)$,而 $a+1$ 只可能是最后一个数。

单次转移可以 $O(n)$。

这个还可以直接倍增转移,考虑从 $f(a,p)$ 转移到 $f(2a,p)$。

其中 $(f(a,0)+f(a,1)+1)$ 表示选择 $\leq a$ 的部分,$f(a,p\ \mathrm{xor}\ (a\bmod 2))$ 表示选择 $\gt a$ 的部分,通过选择 $\leq a$ 然后整体 $+a$ 得到。

那么我们在求解 $a=\lfloor\frac M{2^k}\rfloor$ 时,就可以按照上述方法倍增得到其生成函数。

使用任意模数 FFT 进行转移,单次时间复杂度 $O(N\log N\log M)$。

注意到,我们在求解 $a=\lfloor\frac M 2\rfloor$ 的时候,已经计算了所有 $\lfloor\frac M{2^k}\rfloor$ 的生成函数,因此在倍增的过程中统计答案即可。

总时间复杂度 $O(N\log N\log M)$。

Code:

#include<cstdio>
#include<algorithm>
using namespace std;
const int N=65536;
typedef long long LL;
typedef unsigned long long ULL;
int n,m,md;
LL ans;
inline int pow(int a,int b,const int&md){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;
    return ret;
}
inline int inv(int a,const int&md){return pow(a,md-2,md);}
namespace poly{
    const int md1=998244353,md2=1004535809;
    const LL M=(LL)md1*md2;
    const int iv1_2=inv(md1,md2),iv2_1=inv(md2,md1);
    inline LL mul(LL a,LL b,const LL&md){
        const LL c=a*b-(LL)((long double)a*b/md+.5)*md;
        return(c>>63&md)+c;
    }
    int rev[N],lim;
    struct I{
        int a,b;
        I(){}
        I(int x):a(x%md1),b(x%md2){}
        I(int x,int y):a(x%md1),b(y%md2){}
        I(LL x,LL y):a(x%md1),b(y%md2){}
        I(ULL x,ULL y):a(x%md1),b(y%md2){}
        inline I operator*(const I&r)const{return I((LL)a*r.a,(LL)b*r.b);}
        inline int crt(){
            LL x=(mul((LL)a*md2,iv2_1,M)+mul((LL)b*md1,iv1_2,M))%M;
            return x%md;
        }
    }odd[N],even[N],W[N],iv,O[N],E[N],K[N],odd_[N],even_[N];
    void init_w(){
        for(int i=1;i<n*2;i<<=1){
            const I w=I(pow(3,(md1-1)/(i<<1),md1),pow(3,(md2-1)/(i<<1),md2));
            W[i]=I(1);
            for(int j=1;j<i;++j)W[i+j]=W[i+j-1]*w;
        }
    }
    void init(int n){
        int l=-1;
        for(lim=1;lim<n;lim<<=1)++l;
        for(int i=1;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
        iv=I(inv(lim,md1),inv(lim,md2));
    }
    void FFT(I*a,int f){
        static ULL A[N],B[N];
        for(int i=0;i<lim;++i)A[i]=a[rev[i]].a,B[i]=a[rev[i]].b;
        for(int i=1;i<lim;i<<=1)
            for(int j=0;j<lim;j+=i<<1)
                for(int k=0;k<i;++k){
                    const ULL x1=A[j+k],x2=B[j+k];
                    const ULL y1=A[j+k+i]*W[i+k].a%md1,y2=B[j+k+i]*W[i+k].b%md2;
                    A[j+k]+=y1,B[j+k]+=y2,A[j+k+i]=x1+md1-y1,B[j+k+i]=x2+md2-y2;
                }
        for(int i=0;i<lim;++i)a[i]=I(A[i],B[i]);
        if(!f){
            for(int i=0;i<lim;++i)a[i]=a[i]*iv;
            std::reverse(a+1,a+lim);
        }
    }
    void solve(int m){
        if(m==1)odd[1]=ans=1;else{
            solve(m/2);
            for(int i=0;i<lim;++i)O[i]=E[i]=K[i]=0,odd_[i]=odd[i],even_[i]=even[i];
            for(int i=1;i<n;++i)K[i]=(odd[i].a+even[i].a)%md;K[0]=1;
            FFT(K,1),FFT(odd,1),FFT(even,1);
            if(m>>1&1)for(int i=0;i<lim;++i)O[i]=K[i]*even[i],E[i]=K[i]*odd[i];
            else for(int i=0;i<lim;++i)O[i]=K[i]*odd[i],E[i]=K[i]*even[i];
            FFT(O,0),FFT(E,0);
            for(int i=0;i<n;++i)O[i]=(O[i].crt()+odd_[i].a)%md,E[i]=(E[i].crt()+even_[i].a)%md;
            for(int i=n;i<lim;++i)O[i]=E[i]=0;
            if(m&1)for(int i=n-1;i;--i)O[i]=((i==1)?O[i].a+1:O[i].a+O[i-1].a+E[i-1].a)%md;
            for(int i=0;i<n;++i)odd[i]=O[i],even[i]=E[i];
            for(int i=n;i<lim;++i)odd[i]=even[i]=0;
            for(int i=1;i<n;++++i)ans+=odd[i].a;
        }
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&md),++n;
    if(m==1)return puts("0"),0;
    poly::init_w();
    poly::init(n<<1);
    poly::solve(m/2);
    printf("%lld\n",ans%md);
    return 0;
}