一、引入

随机数据中,BST 一次操作的期望复杂度为 \(\mathcal{O}(\log n)\)

然而,BST 很容易退化,例如在 BST 中一次插入一个有序序列,将会得到一条链,平均每次操作的复杂度为 \(\mathcal{O}(n)\)。我们称这种左右子树大小相差很大的 BST 是“不平衡”的。

有很多方法可以维持 BST 的平衡,从而产生了各种平衡树。

Treap 就是常见平衡树中的一种。

二、简介

满足 BST 性质且中序遍历为相同序列的二叉查找树是不唯一的。这些二叉查找树是等价的,它们维护的是相同的一组数值。在这些二叉查找树上执行同样的操作,将得到相同的结果。

因此,我们可以在维持 BST 性质的基础上,通过改变二叉查找树的 形态,使得树上每个节点的左右子树大小达到平衡,从而使整棵树的深度维持在 \(\mathcal{O}(\log n)\) 级别。

Treap 改变形态并保持 BST 性质的方式为“旋转”,并且保持平衡而不至于退化为链。

Treap=Tree+Heap。Treap 是利用堆的性质来维护平衡的一种平衡树。对每个节点额外存储一个随机值,根据随机值调整 Treap 的形态,使其满足 BST 性质外,还满足父节点的随机值 \(\geq\) 子节点的随机值。

三、Treap

前面说过,为了使 Treap 保持平衡而进行旋转操作。

旋转的本质是将某个节点上移一个位置。旋转需要保证 :

  • 整棵树的中序遍历不变(不能破坏 BST 的性质)。

  • 受影响的节点维护的信息依然正确有效。

每个节点在建立时,赋予其一个随机值,通过旋转操作使得随机值满足大根堆的性质。这样可以使得树高期望保持在  \(\mathcal{O}(\log n)\) 。

1. 旋转操作

在 Treap 中的旋转分为两种:左旋 和 右旋

注意:某些书籍把左右旋定义为一个节点绕其父节点向左或向右旋转。而这里的 Treap 代码仅记录左右子节点,没有记录父节点,方便起见,统一以“旋转前处于父节点位置”(旋转后处于子节点位置)的节点作为左右旋的作用对象。

「算法笔记」旋转 Treap

以右旋为例。如图所示,在初始情况下,\(x\)\(y\) 的左子节点,\(A\)\(B\) 分别是 \(x\) 的左右子树,\(C\)\(y\) 的右子树。

“右旋”操作在保持 BST 性质的基础上,把 \(x\) 变为 \(y\) 的父节点。因为 \(x\) 的关键码小于 \(y\) 的关键码,所以 \(y\) 应该作为 \(x\) 的右子节点。

\(x\) 变成 \(y\) 的父节点后,\(y\) 的左子树就空了出来,于是 \(x\) 原来的右子树 \(B\) 就恰好作为 \(y\) 的左子树。

  • 左旋:将右儿子提到当前节点,自己作为右儿子的左儿子,右儿子原来的左儿子变成自己新的右儿子。

  • 右旋:将左儿子提到当前节点,自己作为左儿子的右儿子,左儿子原来的右儿子变成自己新的左儿子。

右旋将左儿子上移,左旋将右儿子上移。左右旋并 没有本质区别。其目的相同,即将指定节点上移一个位置。

旋转后的二叉树仍满足 BST 的性质。

void zig(int &p){    //右旋操作。zig(p) 可以理解成把 p 的左子节点绕着 p 向右旋转。 
    int q=lc[p];
    lc[p]=rc[q],rc[q]=p,p=q;    //注意 p 是引用 
} 
void zag(int &p){    //左旋操作。zag(p) 可以理解成把 p 的右子节点绕着 p 向左旋转。 
    int q=rc[p];
    rc[p]=lc[q],lc[q]=p,p=q;    //注意 p 是引用 
}

2. 随机权值

合理的旋转操作可使 BST 更“平衡”。如下图,经过一些旋转操作,这棵 BST 变得比较平衡了。

「算法笔记」旋转 Treap

在随机数据下,普通的 BST 就是趋近平衡的。Treap 的思想就是利用“随机”来创造平衡条件。因为在旋转过程中必须维持 BST 性质,所以 Treap 就把“随机”作用在堆性质上。

具体来说,Treap 在插入每个新节点时,给该节点随机生成一个额外的权值。当某个节点不满足大根堆性质时,就执行旋转操作,使该点与其父节点的关系发生对换。

每次删除/插入时通过随机的值决定要不要旋转即可,其他操作与 BST 类似。

特别地,对于删除操作,由于 Treap 支持旋转,我们可以直接找到需要删除的节点,并把它 向下旋转成叶节点,最后直接删除。这样就避免了采取类似普通 BST 的删除方法可能导致的节点信息更新、堆性质维护等复杂问题。

Treap 通过适当的旋转操作,在 维持节点关键码满足 BST 性质的同时,还使每个节点上随机生成的额外权值满足大根堆性质。Treap 是一种平衡的二叉查找树,检索、插入、求前驱后继以及删除节点的时间复杂度都是 \(\mathcal{O}(\log n)\)。

四、模板

Luogu P3369  普通平衡树

Problem:需要写一种数据结构,来维护一些数,其中需要提供以下操作:

  1. 插入数值 \(x\)
  2. 删除数值 \(x\)(若有多个相同的数,应只删除一个)
  3. 查询数值 \(x\) 的排名(若有多个相同的数,应输出最小的排名)
  4. 查询排名为 \(x\) 的数
  5. 求数值 \(x\) 的前驱(前驱定义为小于 \(x\) 的最大的数)
  6. 求数值 \(x\) 的后继(后继定义为大于 \(x\) 的最小的数)

\(1\leq n \leq 10^5,|x| \leq 10^7\)

Solution:平衡树模板题,用 Treap 实现即可。

数据中可能有相同的数值。记 \(cnt(u)\) 表示节点 \(u\) 对应数值的出现次数,初始时为 \(1\)。(这里的“对应数值”就是关键码)

若插入已经存在的数值,就直接把 \(cnt\) 值加 \(1\)。删除时,若 \(cnt(u)>1\),则把 \(cnt(u)\)\(1\);否则删除该节点。

再记 \(sz(u)\) 表示以 \(u\) 为根的子树中所有节点的 \(cnt\) 之和。在插入或删除时从下往上更新 \(sz\) 信息。另外,在旋转操作时,也需要同时修改 \(sz\)

在 BST 检索的基础上,通过判断 \(sz(lc(u))\)\(sz(rc(u))\) 的大小,选择适当的一侧递归,就能查询排名了。

在插入和删除操作时,Treap 的形态会发生变化,一般使用递归实现,以便于在回溯时更新 Treap 上存储的 \(sz\) 等信息。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5; 
int n,opt,x,tot,rt,lc[N],rc[N],val[N],rnd[N],sz[N],cnt[N],ans;    //rnd(u) 表示节点 u 的随机值 
void upd(int p){
    sz[p]=sz[lc[p]]+sz[rc[p]]+cnt[p];
}
int getnew(int k){
    val[++tot]=k,rnd[tot]=rand(),cnt[tot]=sz[tot]=1;
    return tot;
}
void build(){
    getnew(-1e18),getnew(1e18),rt=1,rc[1]=2,upd(rt);
} 
void rotate(int &p,int dir){    //dir= 0 右旋  1 左旋 
    int q=!dir?lc[p]:rc[p];
    if(!dir) lc[p]=rc[q],rc[q]=p,p=q,upd(rc[p]),upd(p);
    else rc[p]=lc[q],lc[q]=p,p=q,upd(lc[p]),upd(p);
}
void insert(int &p,int k){
    if(!p){p=getnew(k);return ;}    
    if(val[p]==k){cnt[p]++,upd(p);return ;}
    if(k<val[p]){insert(lc[p],k);if(rnd[p]<rnd[lc[p]]) rotate(p,0);}    //不满足堆性质,右旋 
    else{insert(rc[p],k);if(rnd[p]<rnd[rc[p]]) rotate(p,1);}    //不满足堆性质,左旋 
    upd(p);
} 
void del(int &p,int k){
    if(!p) return ;
    if(val[p]==k){    //检索到 k 
        if(cnt[p]>1){cnt[p]--,upd(p);return ;}     //有重复,让 cnt 值减 1 即可 
        if(lc[p]||rc[p]){    //不是叶子节点,向下旋转 
            if(!rc[p]||rnd[lc[p]]>rnd[rc[p]]) rotate(p,0),del(rc[p],k);
            else rotate(p,1),del(lc[p],k);
            upd(p);
        }
        else p=0; return ;    //叶子节点直接删除 
    }
    del(k<val[p]?lc[p]:rc[p],k),upd(p);
}
int rank(int p,int k){
    if(!p) return 0;
    if(val[p]==k) return sz[lc[p]]+1;
    return k<val[p]?rank(lc[p],k):rank(rc[p],k)+sz[lc[p]]+cnt[p];
}
int Kth(int p,int rk){
    if(!p) return 1e18;
    if(sz[lc[p]]>=rk) return Kth(lc[p],rk);
    if(sz[lc[p]]+cnt[p]>=rk) return val[p];
    return Kth(rc[p],rk-sz[lc[p]]-cnt[p]); 
}
int pre(int k){
    int ans=1,p=rt;
    while(p){
        if(val[p]==k){
            if(lc[p]>0){p=lc[p]; while(rc[p]>0) p=rc[p]; ans=p;}    //左子树上一直向右走 
            break;
        }
        if(val[p]<k&&val[p]>val[ans]) ans=p;
        p=k<val[p]?lc[p]:rc[p]; 
    }
    return val[ans];
}
int nxt(int k){
    int ans=2,p=rt;
    while(p){
        if(val[p]==k){
            if(rc[p]>0){p=rc[p]; while(lc[p]>0) p=lc[p]; ans=p;}    //右子树上一直向左走 
            break;
        }
        if(val[p]>k&&val[p]<val[ans]) ans=p;
        p=k<val[p]?lc[p]:rc[p]; 
    }
    return val[ans];
}
signed main(){
    scanf("%lld",&n),build();
    while(n--){
        scanf("%lld%lld",&opt,&x),ans=-1;
        if(opt==1) insert(rt,x);
        else if(opt==2) del(rt,x);
        else if(opt==3) ans=rank(rt,x)-1;
        else if(opt==4) ans=Kth(rt,x+1); 
        else if(opt==5) ans=pre(x);
        else ans=nxt(x);
        if(~ans) printf("%lld\n",ans);
    }
    return 0;
}

少了一点压行的版本:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5; 
int n,opt,x,tot,rt,lc[N],rc[N],val[N],rnd[N],sz[N],cnt[N],ans;    //rnd(u) 表示节点 u 的随机值 
void upd(int p){
    sz[p]=sz[lc[p]]+sz[rc[p]]+cnt[p];
}
int getnew(int k){
    val[++tot]=k,rnd[tot]=rand(),cnt[tot]=sz[tot]=1;
    return tot;
}
void build(){
    getnew(-1e18),getnew(1e18),rt=1,rc[1]=2,upd(rt);
} 
void zig(int &p){    //右旋 
    int q=lc[p];
    lc[p]=rc[q],rc[q]=p,p=q,upd(rc[p]),upd(p);
}
void zag(int &p){    //左旋 
    int q=rc[p];
    rc[p]=lc[q],lc[q]=p,p=q,upd(lc[p]),upd(p);
}
void insert(int &p,int k){
    if(!p){p=getnew(k);return ;}    
    if(val[p]==k){cnt[p]++,upd(p);return ;}
    if(k<val[p]){
        insert(lc[p],k);
        if(rnd[p]<rnd[lc[p]]) zig(p);    //不满足堆性质,右旋 
    }
    else{
        insert(rc[p],k);
        if(rnd[p]<rnd[rc[p]]) zag(p);    //不满足堆性质,左旋 
    }
    upd(p);
} 
void del(int &p,int k){
    if(!p) return ;
    if(val[p]==k){    //检索到 k 
        if(cnt[p]>1){cnt[p]--,upd(p);return ;}     //有重复,让 cnt 值减 1 即可 
        if(lc[p]||rc[p]){    //不是叶子节点,向下旋转 
            if(!rc[p]||rnd[lc[p]]>rnd[rc[p]]) zig(p),del(rc[p],k);
            else zag(p),del(lc[p],k);
            upd(p);
        }
        else p=0; return ;    //叶子节点直接删除 
    }
    del(k<val[p]?lc[p]:rc[p],k),upd(p);
}
int rank(int p,int k){
    if(!p) return 0;
    if(val[p]==k) return sz[lc[p]]+1;
    if(k<val[p]) return rank(lc[p],k);
    return rank(rc[p],k)+sz[lc[p]]+cnt[p];
}
int Kth(int p,int rk){
    if(!p) return 1e18;
    if(sz[lc[p]]>=rk) return Kth(lc[p],rk);
    if(sz[lc[p]]+cnt[p]>=rk) return val[p];
    return Kth(rc[p],rk-sz[lc[p]]-cnt[p]); 
}
int pre(int k){
    int ans=1,p=rt;
    while(p){
        if(val[p]==k){
            if(!(p=lc[p])) break;
            while(rc[p]>0) p=rc[p];    //左子树上一直向右走 
            ans=p;break;
        }
        if(val[p]<k&&val[p]>val[ans]) ans=p;
        p=k<val[p]?lc[p]:rc[p]; 
    }
    return val[ans];
}
int nxt(int k){
    int ans=2,p=rt;
    while(p){
        if(val[p]==k){
            if(!(p=rc[p])) break;
            while(lc[p]>0) p=lc[p];    //右子树上一直向左走
            ans=p;break;
        }
        if(val[p]>k&&val[p]<val[ans]) ans=p;
        p=k<val[p]?lc[p]:rc[p]; 
    }
    return val[ans];
}
signed main(){
    scanf("%lld",&n),build();
    while(n--){
        scanf("%lld%lld",&opt,&x),ans=-1;
        if(opt==1) insert(rt,x);
        else if(opt==2) del(rt,x);
        else if(opt==3) ans=rank(rt,x)-1;
        else if(opt==4) ans=Kth(rt,x+1); 
        else if(opt==5) ans=pre(x);
        else ans=nxt(x);
        if(~ans) printf("%lld\n",ans);
    }
    return 0;
}

注:rank(rt,x)-1Kth(rt,x+1) 的加减一是因为初始时额外插入了关键码为 \(+\infty\)\(−\infty\) 的节点。

upd:Treap 求前驱后继这么写更简洁(模板里那个代码懒得改了)

int pre(int k){
    int p=rt,ans=0;
    while(p){
        if(val[p]<k) ans=p,p=rc[p];
        else p=lc[p];
    }
    return ans;
}
int nxt(int k){
    int p=rt,ans=0;
    while(p){
        if(val[p]>k) ans=p,p=lc[p];
        else p=rc[p];
    }
    return ans;
}

五、参考资料

  • 《算法竞赛进阶指南》(大棒子)

相关文章:

  • 2021-12-08
  • 2022-02-17
  • 2022-12-23
  • 2022-02-08
  • 2021-11-09
  • 2021-12-14
  • 2022-01-08
  • 2022-12-23
猜你喜欢
  • 2022-02-20
  • 2022-12-23
  • 2022-02-17
  • 2022-02-09
  • 2021-10-12
  • 2021-08-30
  • 2022-12-23
相关资源
相似解决方案