【WC2020】猜数游戏

原根相关,建图缩点

题目大意:

给定一个数列 $a_1,a_2,a_3,\ldots,a_n$,以及 $p$,保证 $p$ 是一个奇质数的正整数幂。

对于这个数列的任意一个非空子序列 $b_1,b_2,b_3,\ldots,b_k$,你需要选定若干个数 $c_1,c_2,c_3,\ldots,c_t$,满足:

  • $c_i$ 在 $b$ 中出现。
  • 对于任意一个 $b_i$,都能找到一个 $c_j$,满足 $c_j^x\equiv b_i\pmod p$。
  • 要求 $t$ 最小,即选择的数最少。

求 $a$ 的所有非空子序列作为 $b$ 时的最小 $t$ 之和。

题解:

如果两个数 $a,b$ 满足 $a^k\equiv b\pmod p$,则将 $a$ 与 $b$ 之间连一条有向边。考虑每个点产生的贡献。一个点要产生贡献,则必须保证能到达它的点除了它本身都不存在集合中,假设能到达它的点(包括它本身)有 $k$ 个,则能产生 $2^{n-k}$ 的贡献。

当然图中会出现环,所以我们对其缩点后再求解。一个连通块中至少存在一个结点才能被选,所以若有 $t$ 个结点,则贡献应为 $2^{n-k}(2^t-1)$。

那么现在的问题就是如何建图。

考虑 $p$ 是质数的情况。

对于数 $a$,我们令最小的 $x$ 满足 $a^x\equiv 1\pmod p$ 为 $a$ 的,记为 $\mathrm{ord}_pa$。

先考虑把所有 $a_i$ 都变成 $g^{f_i}$ 的形式,其中 $g$ 为 $p$ 的原根。然后 $a_x^k\equiv a_y\pmod p$,相当于 $g^{f_x k}\equiv g^{f_y}\pmod p$。根据欧拉定理,我们相当于要求方程 $f_x\cdot k\equiv f_y\pmod{\varphi(p)}$ 有正整数解。

这个方程有解的充要条件是 $\gcd(f_x,\varphi(p))\ |\ f_y$,进一步转化后可以变为 $\gcd(f_x,\varphi(p))\ | \ \gcd(f_y,\varphi(p))$。

这个转化有什么用呢?我们考虑 $a\equiv g^n\pmod p$ 以及 $a^x\equiv 1\pmod p$,则 $g^{nx}\equiv 1\pmod p$,又根据欧拉定理,$g^{\varphi(p)}\equiv 1\pmod p$,所以 $\varphi(p)\ |\ nx$。又因为 $x$ 要最小,所以 $x=\frac{\varphi(p)}{\gcd(n,\varphi(p))}$。

即 $\mathrm{ord}_p a_x=\frac{\varphi(p)}{\gcd(f_x,\varphi(p))}$。因此 $\gcd(f_x,\varphi(p))\ | \ \gcd(f_y,\varphi(p))$ 等价于 $\mathrm{ord}_p a_y\ |\ \mathrm{ord}_p a_x$。

因此我们只需要求出每个数的阶就可以快速判断两个点直接是否有边了。

至于如何求 $a$ 的阶,我们采用试除法,即枚举 $\varphi(p)$ 的质因数,并从 $x=\varphi(p)$ 开始,每次尝试除掉一个质因数看是否仍满足 $a^x\equiv 1\pmod p$。单次时间复杂度 $O(\log^2 p)$。

这样就可以在 $O(n^2+n\log^2 p)$ 的时间内完成建图。


考虑 $p$ 不是质数的情况。我们先找出 $p$ 的质因子 $p_0$,这部分可以 $O(\sqrt p)$ 暴力。

对于 $a$ 不是 $p_0$ 的倍数的情况,由于 $p$ 存在原根,因此可以按照上述方法来连边。

对于 $a$ 是 $p_0$ 的情况,令 $p=p_0^k$,则 $a^k$ 必定为 $0$。因此这部分可以暴力进行连边。

这样建图部分就完成了。

我们还需要对每个点计算能到达它的点的个数。从建图的过程中可以看出,如果 $a$ 到 $b$ 有边,$b$ 到 $c$ 有边,则 $a$ 到 $c$ 也一定有边。因此直接根据边来计算即可。

总时间复杂度 $O(n^2+\sqrt p+n\log^2 p)$。

Code:

#include<iostream>
#include<vector>
#include<queue>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=5005,mod=998244353;
int n,md,p,phi,ord[N],dfn[N],low[N],SCC,in[N],idx,sta[N],top,sz[N],deg[N],sm[N],ans,_2[N],_ct[N];
bool vis[N],vv[N][N];
vector<int>a,b,fc;
vector<int>G[N],_G[N];
inline int id(int x){
    const int s=lower_bound(a.begin(),a.end(),x)-a.begin();
    if(s==(int)a.size()||a[s]!=x)return-1;
    return s;
}
inline int pow(int a,int b){
    int res=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)res=(LL)res*a%md;
    return res;
}
void solve(int n){
    for(int i=2;i*i<=n;++i)if(n%i==0){
        fc.push_back(i);
        while(n%i==0)n/=i;
    }
    if(n>1)fc.push_back(n);
}
void dfs(int now){
    dfn[now]=low[now]=++idx;
    sta[++top]=now;
    vis[now]=1;
    for(int to:G[now])if(!dfn[to])dfs(to),low[now]=min(low[now],low[to]);
    else if(vis[to])low[now]=min(low[now],dfn[to]);
    if(low[now]==dfn[now]){
        ++SCC;
        int v;
        do{
            v=sta[top--];
            vis[v]=0;
            in[v]=SCC;
            sz[SCC]+=_ct[v];
        }while(v!=now);
    }
}
void bfs(){
    static queue<int>q;
    for(int i=1;i<=SCC;++i)if(!deg[i])q.push(i);
    const int nn=b.size();
    while(!q.empty()){
        int u=q.front();q.pop();
        ans=(ans+_2[nn-sm[u]]*(_2[sz[u]]-1LL))%mod;
        for(int to:_G[u])if(!--deg[to])q.push(to);
    }
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>md;
    for(int i=*_2=1;i<=n;++i)_2[i]=(_2[i-1]<<1)%mod;
    for(int x;n--;a.push_back(x))cin>>x;
    b=a;
    sort(a.begin(),a.end()),a.erase(unique(a.begin(),a.end()),a.end()),n=a.size();
    for(int i:b)++_ct[id(i)];
    p=md;
    for(int i=2;i*i<=md;++i)if(md%i==0){p=i;break;}
    phi=md/p*(p-1);
    solve(phi);
    for(int i=0;i<n;++i)if(a[i]%p){
        int&f=ord[i];f=phi;
        for(int x:fc)while(f%x==0&&pow(a[i],f/x)==1)f/=x;
    }
    for(int i=0;i<n;++i)if(a[i]%p){
        for(int j=0;j<n;++j)if(i!=j&&ord[j]&&ord[i]%ord[j]==0)G[i].push_back(j);
    }else{
        const int k=a[i];
        for(int s=(LL)k*k%md;s;s=(LL)s*k%md){
            int x=id(s);
            if(x!=-1)G[i].push_back(x);
        }
    }
    for(int i=0;i<n;++i)if(!dfn[i])dfs(i);
    for(int i=1;i<=SCC;++i)sm[i]=sz[i];
    for(int i=0;i<n;++i)for(int to:G[i])if(in[i]!=in[to]&&!vv[in[i]][in[to]])++deg[in[to]],_G[in[i]].push_back(in[to]),sm[in[to]]+=sz[in[i]],vv[in[i]][in[to]]=1;
    bfs();
    cout<<ans<<'\n';
    return 0;
}