【洛谷P6301】集合

压位 $\omega$ 叉树。

题目大意:

你需要实现一种数据结构,用来维护一个没有重复元素的整数集合(std::set),支持以下操作:

  1. 插入一个数 $x$。
  2. 删除一个数 $x$。
  3. 求集合中最小的元素。如果集合为空则为 -1
  4. 求集合中最大的元素。如果集合为空则为 -1
  5. 求集合中的元素个数。
  6. 判断 $x$ 是否在集合中。
  7. 给定 $x$,查询小于 $x$ 的最大数。如果不存在则为 -1
  8. 给定 $x$,查询大于等于 $x$ 的最小数。如果不存在则为 -1

强制在线。

题解:

其实这个 trick 很早就有,某些毒瘤经常拿它来搞事情。

大概的思路就是,我们搞一棵 $\omega$ 叉的树来维护这些信息。其中 $\omega$ 是计算机能并行处理的二进制位数,一般为 $32$ 或 $64$。这棵树的高度是 $\frac{\log ⁡n}{\log \omega}$ 的。

对于每个结点,我们维护它的最小值、最大值,以及每一个儿子结点是否有元素出现。我们显然可以把每一个儿子结点是否有元素出现,用一个 $\omega$ 位无符号整数 $w$ 来压。

考虑如何实现这几个操作。

对于操作 $1$,我们在 $\omega$ 叉树上递归。对于一个结点,我们可以通过位运算轻易地找出它应该在第几个儿子里,并且往下递归即可。在最后一层,也可以用位运算快速标记这个数。时间复杂度 $O(\frac{\log ⁡n}{\log \omega})$。操作 $2$ 和操作 $6$ 都是类似的。

操作 $5$ 可以直接用一个变量记录。

对于操作 $7$,我们还是考虑在 $\omega$ 叉树上递归。假设它的前 $\log\omega$ 位是 $x$,那么它的前驱的前 $\log\omega$ 位必定 $\leq x$。

对于等于 $x$ 的情况,我们还需要看下一层的最小值是否小于现在要查询的值,如果是,则往下递归。如果不是,则我们需要找到当前的 $w$ 中,小于 $x$ 的最大的是 $1$ 的位置即可。这个也可以通过位运算得到。然后往下递归即可。时间复杂度 $O(\frac{\log ⁡n}{\log \omega})$。操作 $8$ 是类似的。

对于操作 $3,4$,由于我们已经对每层结点记录了最值,所以直接访问即可。最值的维护也需要涉及到前驱 $1$ 的出现位置以及后继 $1$ 的出现位置。

这样所有的操作,我们都在不高于 $O(\frac{\log ⁡n}{\log \omega})$ 的时间复杂度内完成了。

总时间复杂度 $O(m\cdot \frac{\log ⁡n}{\log \omega})$,空间复杂度 $O(n)$。

Code:

#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#include<cstdio>
#include<cctype>
typedef unsigned long long uLL;
int n,m,last,SZ;
long long ans;
char buf[(int)1e8],*ss=buf;
inline int init(){buf[fread(buf,1,(int)1e8-1,stdin)]='\n';fclose(stdin);return 0;}
const int __START__=init();
inline int readint(){
    int d=0,f=0;
    while(!isdigit(*ss))f=*ss++=='-';
    while(isdigit(*ss))d=d*10+(*ss++^'0');
    return f?-d:d;
}
inline void set(uLL&x,int v){x|=1uLL<<v;}
inline void reset(uLL&x,int v){x&=~(1uLL<<v);}
inline int lbt(uLL x){return __builtin_ctzll(x);}
inline int hbt(uLL x){return 63-__builtin_clzll(x);}
inline int pre(const uLL&w,int x){return hbt(w<<(63-x)<<1)-(64-x);}
inline int suc(const uLL&w,int x){return lbt(w>>x>>1)+x+1;}
template<const int size,const int bit>
struct TREE{
    uLL w;int min,max;
    TREE<size/64,bit-6>s[64];
    inline void init(){
        min=1<<30,max=-1;
        for(int i=0;i<64;++i)s[i].init();
    }
    inline bool any(){return w;}
    inline int idx(int x){return x&((1u<<bit)-1);}
    inline bool set(int x){
        bool res=s[x>>bit].set(idx(x));
        ::set(w,x>>bit);
        const int L=lbt(w),H=hbt(w);
        min=L<<bit|s[L].min,max=H<<bit|s[H].max;
        return res;
    }
    inline bool reset(int x){
        bool res=s[x>>bit].reset(idx(x));
        if(!s[x>>bit].any())::reset(w,x>>bit);
        if(any()){
            const int L=lbt(w),H=hbt(w);
            min=L<<bit|s[L].min,max=H<<bit|s[H].max;
        }else min=1<<30,max=-1;
        return res;
    }
    inline int pre(int x){
        if(x<=min)return 1<<30;
        if(s[x>>bit].min<idx(x))return s[x>>bit].pre(idx(x))|(x>>bit<<bit);
        int p=::pre(w,x>>bit);
        return s[p].max|(p<<bit);
    }
    inline int suc(int x){
        if(x>max)return-1;
        if(s[x>>bit].max>=idx(x))return s[x>>bit].suc(idx(x))|(x>>bit<<bit);
        int c=::suc(w,x>>bit);
        return s[c].min|(c<<bit);
    }
    inline bool find(int x){return(w>>(x>>bit)&1)?s[x>>bit].find(idx(x)):0;}
};
template<>
struct TREE<64,0>{
    uLL w;int min,max;
    inline void init(){min=1<<30,max=-1;}
    inline bool any(){return w;}
    inline bool set(int x){
        if(w>>x&1)return 0;
        ::set(w,x),min=lbt(w),max=hbt(w);
        return 1;
    }
    inline bool reset(int x){
        if(w>>x&1){
            ::reset(w,x);
            if(any())min=lbt(w),max=hbt(w);else min=1<<30,max=-1;
            return 1;
        }
        return 0;
    }
    inline int pre(int x){return::pre(w,x);}
    inline int suc(int x){return lbt(w>>x)+x;}
    inline bool find(int x){return w>>x&1;}
};
TREE<1<<30,24>s;
int main(){
    n=readint(),m=readint();
    s.init();
    while(m--){
        int op=readint();
        switch(op){
            case 1:SZ+=s.set(readint()+last);break;
            case 2:SZ-=s.reset(readint()+last);break;
            case 3:last=s.min;if(last==(1<<30))last=-1;ans+=last;break;
            case 4:last=s.max,ans+=last;break;
            case 5:ans+=(last=SZ);break;
            case 6:ans+=(last=s.find(readint()+last));break;
            case 7:last=s.pre(readint()+last);if(last==(1<<30))last=-1;ans+=last;break;
            case 8:last=s.suc(readint()+last),ans+=last;
        }
    }
    printf("%lld\n",ans);
    return 0;
}