【NOI2017】蚯蚓排队

哈希+链表

题目大意:

有 $n$ 个 $[1,6]$ 的数字,组成字符串集,一开始每个数字单独为一个串。

有三个操作:

  1. 将两个数首尾相连拼接为一个串。
  2. 将一个串从某个位置处断开。
  3. 给定串 $s$ 和正整数 $k$,令 $f(t)$ 表示串 $t$ 在字符串集中作为子串的出现次数。求 $s$ 所有长度为 $k$ 的子串$s’$ 的 $f(s’)$ 之积。

保证 2 操作最多 $1000$ 次,$1\leq k\leq 50$。

题解:

观察发现 $k$ 非常小,所以不难想到一个看起来非常暴力的做法。

用双向链表来维护每个串,每次合并两个串的时候,以 $k^2$ 的复杂度得到中间多出来的所有子串,并对其哈希,然后累加每个串的出现次数。分裂的时候同理删除即可。

然后查询的时候,对于每个长度为 $k$ 的子串,直接 $O(1)$ 查询其出现次数即可。那么查询的复杂度为线性。

累加串的出现次数的时候,我们另外再开一个哈希表,把串的哈希值传进去存储即可。期望每次 $O(1)$。

这里使用 C++11 自带的unordered_map会因常数原因而超时。实际上手写哈希表也不难。

那么时间复杂度 $O(mk^2+\sum|s|)$。

这样分析显然会炸。

实际上,如果没有 2 操作,其最终的总子串格式是 $nk$ 个,那么 1 操作的总复杂度就为 $O(nk)$。

而 2 操作的个数 $c\leq 1000$。因此只会最多产生 $O(ck^2)$ 的复杂度。

这样,总时间复杂度就变为 $O(nk+ck^2+\sum|s|)$。可以通过。

Code:

#include<iostream>
using namespace std;
typedef unsigned long long LL;
const int base=7,N=2e5+5;
int n,m,pre[N],nxt[N],val[N];
LL _[52];
struct Hash_Table{
    static const int M=19260819,md=M-2;
    int cnt[M],head[M],nxt[M],tot;
    LL val[M];
    Hash_Table(){tot=0;}
    inline int&operator[](const LL&k){
        const int x=k%md;
        for(int i=head[x];i;i=nxt[i])
            if(val[i]==k)return cnt[i];
        nxt[++tot]=head[x],head[x]=tot,val[tot]=k;
        return cnt[tot];
    }
    inline int operator[](const LL&k)const{
        const int x=k%md;
        for(int i=head[x];i;i=nxt[i])
            if(val[i]==k)return cnt[i];
        return 0;
    }
}ct;
void merge(int L,int R){
    int l=L,r;
    LL hs=val[L],bs=1;
    for(int l1=1;l1<50&&l;++l1,l=pre[l],bs*=base,hs+=val[l]*bs){
        LL H=hs*base+val[R];r=R;
        for(int r1=1;l1+r1<=50&&r;++r1,r=nxt[r],H=H*base+val[r])
            ++ct[H];
    }
    nxt[L]=R,pre[R]=L;
}
void split(int L){
    const int R=nxt[L];
    int l=L,r;
    LL hs=val[L],bs=1;
    for(int l1=1;l1<50&&l;++l1,l=pre[l],bs*=base,hs+=val[l]*bs){
        LL H=hs*base+val[R];r=R;
        for(int r1=1;l1+r1<=50&&r;++r1,r=nxt[r],H=H*base+val[r])
            --ct[H];
    }
    nxt[L]=pre[R]=0;
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=*_=1;i<=50;++i)_[i]=_[i-1]*base;
    for(int i=1;i<=n;++i){
        cin>>val[i];
        ++ct[val[i]];
    }
    while(m--){
        int op;
        cin>>op;
        if(op==1){
            int x,y;
            cin>>x>>y;
            merge(x,y);
        }else
            if(op==2){
                int x;
                cin>>x;
                split(x);
            }
            else{
                int k;
                static char s[20000002];
                cin>>s>>k;
                LL hx=0;
                for(int i=0;i<k;++i)
                    hx=hx*base+(s[i]^'0');
                int ans=1;
                for(int l=0,r=k-1;s[r];++l){
                    ans=(LL)ans*ct[hx]%998244353;
                    if(ans==0)break;
                    hx=hx*base-_[k]*(s[l]^'0')+(s[++r]^'0');
                }
                cout<<ans<<'\n';
            }
    }
    return 0;
}