【洛谷P3413】SAC#1 - 萌数

数位 DP

题目大意:

给定 $l,r$,问 $[l,r]$ 里有多少整数,满足存在一个长度大于 $1$ 的子串是回文串(不包含前导 $0$)。

题解:

我们只需要考虑长度为 $2$ 或 $3$ 的回文串是否出现,因为更长的去掉头尾两个字符一定还是一个回文串。因此记录状态只需记录两位。

从高位往低位数位 dp。设 $f[n][w][t1][t2]$ 表示当前考虑的最后两位是 $n$,有效位数为 $w$($w\in[0,3]$),$t1$ 表示当前是否已经有过回文串,$t2$ 表示当前已经确定的位是否和上界完全匹配。

记录有效位数是为了判断回文,如 $0$ 在有效位数为 $2$ 的时候是算回文串的($00$)。

转移时枚举每个状态,枚举下一位是什么,判断一下回文,进行转移即可。

注意状态不要太多,否则转移的时候会超时(原来 $n$ 记录的三位)。

Code:

#include<iostream>
#include<string>
#include<sstream>
#include<algorithm>
#include<cstring>
using namespace std;
const int md=1e9+7;
inline bool check(int n,int w){
    if(w==2)return(n/10)==(n%10);
    if(w==3)return(n/100)==(n%10)||check(n/10,2)||check(n%100,2);
    return 0;
}
string l,r;
int dp[2][101][4][2][2];
inline void upd(int&a){a+=a>>31&md;}
inline void inc(int&a,int b){upd(a+=b-md);}
int ask(string s){
    if(s.length()<4){
        int n,ret=0;
        istringstream in(s);
        in>>n;
        for(int i=1;i<=n;++i)
            ret+=check(i,i<10?1:(i<100?2:3));
        return ret;
    }
    memset(dp,0,sizeof dp);
    int mx=(s[0]^'0')*10+(s[1]^'0');
    for(int i=0;i<=mx;++i){
        int w=(i==0?0:(i<10?1:2));
        ++dp[0][i][w][check(i,w)][i==mx];
    }
    for(int wg=2;wg<s.length();++wg){
        memset(dp[1],0,sizeof*dp);
        int c=s[wg]^'0';
        for(int num=0;num<100;++num)
            for(int ws=0;ws<=3;++ws)
                for(int ok=0;ok<2;++ok)
                    for(int bg=0;bg<2;++bg){
                        for(int nxt=0;nxt<10;++nxt){
                            int now=(num*10+nxt)%1000,nws=now==0&&ws==0?0:min(ws+1,3);
                            int nok=ok||check(now,nws);
                            inc(dp[1][now%100][nws][nok][bg&&nxt==c],dp[0][num][ws][ok][bg]);
                            if(bg&&nxt==c)break;
                        }
                    }
        memcpy(dp[0],dp[1],sizeof*dp);
    }
    int ret=0;
    for(int i=0;i<100;++i)
        for(int w=0;w<=3;++w)
            for(int bg=0;bg<2;++bg)
                upd(ret+=dp[0][i][w][1][bg]-md);
    return ret;
}
int main(){
    cin>>l>>r;
    --l[l.length()-1];
    for(int i=l.length()-1;i>0&&l[i]<'0';--i)
        --l[i-1],l[i]+=10;
    while(l.length()>1&&l[0]=='0')l.erase(0,1);
    cout<<(ask(r)-ask(l)+md)%md<<endl;
    return 0;
}