【Codeforces 585F】Digits of Number Pi

AC 自动机上数位 DP。

题目大意:

给定一个长度为 $n$ 的数字串 $s$ 和两个长度为 $d$ 的不含前导零数字串 $L,R$。

求有多少个长度为 $d$ 的不含前导零的数字串满足:

  • 其数值大小在 $[L,R]$ 之间。
  • 其存在长度至少为 $\lfloor\frac d 2\rfloor$ 的子串是 $s$ 的子串。

题解:

将 $s$ 的每个长度为 $\lfloor\frac d 2\rfloor$ 的子串拿出来,对这些串建 AC 自动机。

然后在 AC 自动机上跑数位 DP 即可。

令 $f[i][u][0/1][0/1]$ 表示当前考虑到第 $i$ 位,匹配到自动机上的 $u$ 号结点,是否已经遇到过一个匹配点(即存在子串是 $s$ 的子串),当前匹配到的位是否和上界相等。

时间复杂度 $O(nd^2)$。

Code:

#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
using namespace std;
const int N=50005,md=1e9+7;
int ch[N][10],nod,fail[N];
bool vis[N];
char L[55],R[55],S[N];
int d,k,n,ans;
inline void upd(int&a){a+=a>>31&md;}
void insert(char*s){
    int nw=0;
    for(int i=0;i<k;++i){
        const int x=s[i]^'0';
        if(!ch[nw][x])ch[nw][x]=++nod;
        nw=ch[nw][x];
    }
    vis[nw]=1;
}
queue<int>q;
void build_AC(){
    for(int i=0;i<10;++i)
    if(ch[0][i])q.push(ch[0][i]);
    while(!q.empty()){
        int u=q.front();
        q.pop();
        for(int c=0;c<10;++c)if(ch[u][c]){
            int v=ch[u][c];
            fail[v]=ch[fail[u]][c];
            q.push(v);
            vis[v]|=vis[fail[v]];
        }else ch[u][c]=ch[fail[u]][c];
    }
}
int f[52][N][2][2];
int calc(char*R){
    if(R[d-1]=='0')return 0;
    memset(f,0,sizeof f);
    for(int i=1;i<(R[d-1]^'0');++i)++f[d-1][ch[0][i]][vis[ch[0][i]]][0];
    f[d-1][ch[0][R[d-1]^'0']][vis[ch[0][R[d-1]^'0']]][1]=1;
    for(int i=d-2;~i;--i)
    for(int t=0;t<=nod;++t)
    for(int pp=0;pp<2;++pp)
    for(int zt=0;zt<2;++zt)if(f[i+1][t][pp][zt]){
        const int&F=f[i+1][t][pp][zt];
        for(int c=0;c<10;++c){
            if(zt&&c>(R[i]^'0'))break;
            const int to=ch[t][c];
            const int npp=pp||vis[to],nzt=zt&&c==(R[i]^'0');
            upd(f[i][to][npp][nzt]+=F-md);
        }
    }
    int ret=0;
    for(int i=0;i<=nod;++i)
    upd(ret+=f[0][i][1][0]-md),upd(ret+=f[0][i][1][1]-md);
    return ret;
}
int main(){
    scanf("%s%s%s",S,L,R);
    d=strlen(L),k=d>>1,n=strlen(S);
    for(int i=0;i+k<=n;++i)
    insert(S+i);
    build_AC();
    reverse(L,L+d),reverse(R,R+d);
    --L[0];
    for(int i=0;i<d&&L[i]<'0';++i)
    L[i]+=10,--L[i+1];
    upd(ans=calc(R)-calc(L));
    printf("%d\n",ans);
    return 0;
}