Splay
参考:https://tiger0132.blog.luogu.org/slay-notes
普通模板(按值维护可重集合:Find/Insert/Kth/Pre/Succ/Remove;按位置维护序列:build/Kth/Reverse/Output,支持带懒标记的区间翻转):
// Array-based Splay tree template. Node 0 is the null node.
//  * Value mode (Find / Insert / Kth / Pre / Succ / Remove): an ordered multiset where
//    node x holds cnt[x] copies of val[x] and sz[x] counts elements (with multiplicity)
//    in x's subtree.
//  * Sequence mode (build / Kth / Reverse / Output): the in-order traversal is a
//    sequence; lines marked "/// range reversal" implement lazy range reversal.
// Rotate() writes ch[0][*] and fa[0] whenever the parent/child is the null node, so
// ch[0] may hold stale links: code must never walk down starting from node 0.
const int N = 1e5 + 10;
int ch[N][2], val[N], cnt[N], fa[N], sz[N], lazy[N], ncnt = 0, rt = 0;
int n, m;           // n: Output() only prints values in [1, n]
std::queue<int> q;  // free list: ids returned by Recycle(), reused by NewNode()

// 1 if x is the right child of its parent, 0 otherwise.
inline int ck(int x) { return ch[fa[x]][1] == x; }

// Recompute sz[x] from its children and its own multiplicity.
inline void push_up(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

/// range reversal: apply x's pending reversal and hand the tag to its children.
inline void push_down(int x) {
    if(lazy[x]) {
        std::swap(ch[x][0], ch[x][1]);
        lazy[ch[x][0]] ^= 1;
        lazy[ch[x][1]] ^= 1;
        lazy[x] = 0;
    }
}

// Rotate x above its parent, preserving in-order and updating sz of both nodes.
void Rotate(int x) {
    int y = fa[x], z = fa[y];
    push_down(y), push_down(x);  /// range reversal: settle tags before reading sides
    int k = ck(x), w = ch[x][k ^ 1];
    ch[y][k] = w, fa[w] = y;
    ch[z][ck(y)] = x, fa[x] = z;
    ch[x][k ^ 1] = y, fa[y] = x;
    push_up(y), push_up(x);
}

// Splay x until its parent is goal; goal == 0 makes x the root and updates rt.
void Splay(int x, int goal = 0) {
    push_down(x);  /// range reversal
    while(fa[x] != goal) {
        int y = fa[x], z = fa[y];
        if(z != goal) {
            if(ck(x) == ck(y)) Rotate(y);  // zig-zig: rotate the parent first
            else Rotate(x);                // zig-zag
        }
        Rotate(x);
    }
    if(!goal) rt = x;
}

// Splay the node holding value x to the root. If x is absent, the last node on the
// search path (x's predecessor or successor) becomes the root instead.
void Find(int x) {
    if(!rt) return;
    int cur = rt;
    while(ch[cur][x > val[cur]] && x != val[cur]) cur = ch[cur][x > val[cur]];
    Splay(cur);
}

// Allocate a detached node holding x, reusing a recycled id when one is available.
inline int NewNode(int x) {
    int cur;
    if(q.empty()) cur = ++ncnt;
    else cur = q.front(), q.pop();
    ch[cur][0] = ch[cur][1] = fa[cur] = lazy[cur] = 0;
    val[cur] = x;
    sz[cur] = cnt[cur] = 1;
    return cur;
}

// Push every node id of x's subtree onto the free list (children before parent).
void Recycle(int x) {
    if(!x) return;  // never descend from node 0: ch[0] may hold stale links
    Recycle(ch[x][0]);
    Recycle(ch[x][1]);
    q.push(x);
}

// Insert one copy of value x and splay its node to the root.
void Insert(int x) {
    int cur = rt, p = 0;
    while(cur && val[cur] != x) {
        p = cur;
        cur = ch[cur][x > val[cur]];
    }
    if(cur) {
        cnt[cur]++;
        push_up(cur);  // Splay() does no rotation (hence no push_up) if cur is the root
    } else {
        cur = NewNode(x);
        fa[cur] = p;
        if(p) ch[p][x > val[p]] = cur;
    }
    Splay(cur);
}

// Node holding the k-th smallest element (1-based), or 0 if k exceeds the size.
// Pushes reversal tags down along the path but does not splay the result; splay it
// afterwards to keep the amortized O(log n) bound.
int Kth(int k) {
    int cur = rt;
    while(cur) {
        push_down(cur);  /// range reversal
        if(ch[cur][0] && k <= sz[ch[cur][0]]) cur = ch[cur][0];
        else if(k > sz[ch[cur][0]] + cnt[cur]) k -= sz[ch[cur][0]] + cnt[cur], cur = ch[cur][1];
        else return cur;
    }
    return 0;
}

// Leftmost / rightmost node of x's subtree (0 if x == 0). No reversal tags are pushed.
inline int get_min(int x) { while(x && ch[x][0]) x = ch[x][0]; return x; }
inline int get_max(int x) { while(x && ch[x][1]) x = ch[x][1]; return x; }

// Node with the largest value < x, or 0 if there is none.
int Pre(int x) {
    if(!rt) return 0;
    Find(x);
    if(val[rt] < x) return rt;
    return get_max(ch[rt][0]);  // empty left subtree -> 0
}

// Node with the smallest value > x, or 0 if there is none.
int Succ(int x) {
    if(!rt) return 0;
    Find(x);
    if(val[rt] > x) return rt;
    return get_min(ch[rt][1]);  // empty right subtree -> 0
}

// Remove one copy of x. Requires x to have both a predecessor and a successor in the
// tree (e.g. INT_MIN / INT_MAX sentinels); the multiset is unchanged if x is absent.
void Remove(int x) {
    int last = Pre(x), next = Succ(x);
    Splay(last), Splay(next, last);  // x's node, if any, is now exactly ch[next][0]
    int del = ch[next][0];
    if(cnt[del] > 1) cnt[del]--, Splay(del);
    else ch[next][0] = 0, push_up(next), push_up(last);
}

/// range reversal: reverse sequence positions l..r (1-based). Assumes one sentinel
/// node before the first element and one after the last, so position i is rank i+1.
void Reverse(int l, int r) {
    int x = Kth(l), y = Kth(r + 2);
    Splay(x), Splay(y, x);
    lazy[ch[y][0]] ^= 1;  // ch[y][0] now spans exactly positions l..r
}

/// range reversal: print, in order, the values of x's subtree that lie in [1, n]
/// (other values, presumably sentinels, are skipped).
void Output(int x) {
    push_down(x);
    if(ch[x][0]) Output(ch[x][0]);
    if(1 <= val[x] && val[x] <= n) printf("%d ", val[x]);
    if(ch[x][1]) Output(ch[x][1]);
}

// Delete the root node (all its copies); its in-order successor, if any, becomes the
// new root. The removed id is not recycled.
void delete_root() {
    if(ch[rt][1]) {
        int cur = ch[rt][1];
        while(cur && ch[cur][0]) cur = ch[cur][0];
        Splay(cur, rt);  // successor becomes rt's right child with no left child
        ch[cur][0] = ch[rt][0];
        fa[ch[cur][0]] = cur;
        rt = cur;
    }
    else rt = ch[rt][0];
    fa[rt] = 0;
    if(rt) push_up(rt);
}

// Build a perfectly balanced subtree over a[l..r] (in-order == array order) and
// return its root, or 0 for an empty range. The caller assigns the result to rt.
int build(int l, int r, int *a) {
    if(l > r) return 0;
    int mid = (l + r) >> 1, cur = NewNode(a[mid]);
    if(l == r) return cur;
    ch[cur][0] = build(l, mid - 1, a);
    if(ch[cur][0]) fa[ch[cur][0]] = cur;
    ch[cur][1] = build(mid + 1, r, a);
    if(ch[cur][1]) fa[ch[cur][1]] = cur;
    push_up(cur);
    return cur;
}

// Reset to an empty tree (e.g. between test cases).
inline void init() {
    ncnt = rt = ch[0][0] = ch[0][1] = fa[0] = sz[0] = cnt[0] = val[0] = lazy[0] = 0;
    while(!q.empty()) q.pop();  // ids are handed out again by ++ncnt; drop stale free list
}
按排名插入模板(Insert(x, y):把 y 插到第 x 个元素之后,x = 0 时插到最前面;常数似乎比之前的写法小,原因未确认,也可能只是以前的写法不够好):
inline void Newnode(int &cur, int f, int a) { cur = ++ncnt; fa[cur] = f; val[cur] = a; ch[cur][0] = ch[cur][1] = 0; sz[cur] = cnt[cur] = 1; } void Insert(int x, int y) { int p = 0; if(!rt) { Newnode(rt, 0, y); return ; } if(!x) { p = rt; sz[p]++; while(ch[p][0]) p = ch[p][0], sz[p]++; Newnode(ch[p][0], p, y); Splay(ch[p][0]); return ; } int u = Kth(x); Splay(u); Newnode(rt, 0, y); ch[rt][1] = ch[u][1]; fa[ch[rt][1]] = rt; ch[u][1] = 0; ch[rt][0] = u; fa[u] = rt; push_up(u), push_up(rt); }
例题:普通平衡树(如洛谷 P3369),操作 1~6 依次为:插入 x、删除一个 x、查询 x 的排名、查询排名为 x 的数、求 x 的前驱、求 x 的后继。
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

// Value-mode Splay tree (multiset): node x holds cnt[x] copies of val[x]; sz[x] counts
// elements (with multiplicity) in x's subtree. Node 0 is the null node.
const int N = 2e5 + 10;
int ch[N][2], val[N], cnt[N], fa[N], sz[N], ncnt = 0, rt = 0;

// 1 if x is the right child of its parent, 0 otherwise.
inline int ck(int x) { return ch[fa[x]][1] == x; }

// Recompute sz[x] from its children and its own multiplicity.
inline void push_up(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

// Rotate x above its parent, preserving in-order and updating sz of both nodes.
void Rotate(int x) {
    int y = fa[x], z = fa[y], k = ck(x), w = ch[x][k ^ 1];
    ch[y][k] = w, fa[w] = y;
    ch[z][ck(y)] = x, fa[x] = z;
    ch[x][k ^ 1] = y, fa[y] = x;
    push_up(y), push_up(x);
}

// Splay x until its parent is goal; goal == 0 makes x the root and updates rt.
void Splay(int x, int goal = 0) {
    while(fa[x] != goal) {
        int y = fa[x], z = fa[y];
        if(z != goal) {
            if(ck(x) == ck(y)) Rotate(y);  // zig-zig
            else Rotate(x);                // zig-zag
        }
        Rotate(x);
    }
    if(!goal) rt = x;
}

// Splay the node holding x (or, if absent, x's predecessor or successor) to the root.
void Find(int x) {
    if(!rt) return;
    int cur = rt;
    while(ch[cur][x > val[cur]] && x != val[cur]) cur = ch[cur][x > val[cur]];
    Splay(cur);
}

// Insert one copy of x and splay its node to the root.
void Insert(int x) {
    int cur = rt, p = 0;
    while(cur && val[cur] != x) {
        p = cur;
        cur = ch[cur][x > val[cur]];
    }
    if(cur) {
        cnt[cur]++;
        push_up(cur);  // no rotation (hence no push_up) happens if cur is already the root
    } else {
        cur = ++ncnt;
        if(p) ch[p][x > val[p]] = cur;
        fa[cur] = p;
        ch[cur][0] = ch[cur][1] = 0;
        val[cur] = x;
        cnt[cur] = sz[cur] = 1;
    }
    Splay(cur);
}

// Node holding the k-th smallest element (1-based); k must be within the tree size.
// The caller splays the result.
int Kth(int k) {
    int cur = rt;
    while(true) {
        if(ch[cur][0] && k <= sz[ch[cur][0]]) cur = ch[cur][0];
        else if(k > sz[ch[cur][0]] + cnt[cur]) k -= sz[ch[cur][0]] + cnt[cur], cur = ch[cur][1];
        else return cur;
    }
}

// Node with the largest value < x (the INT_MIN sentinel guarantees one exists).
int Pre(int x) {
    Find(x);
    if(val[rt] < x) return rt;
    int cur = ch[rt][0];
    while(ch[cur][1]) cur = ch[cur][1];
    return cur;
}

// Node with the smallest value > x (the INT_MAX sentinel guarantees one exists).
int Succ(int x) {
    Find(x);
    if(val[rt] > x) return rt;
    int cur = ch[rt][1];
    while(ch[cur][0]) cur = ch[cur][0];
    return cur;
}

// Remove one copy of x (no-op on the multiset if x is absent).
void Remove(int x) {
    int last = Pre(x), next = Succ(x);
    Splay(last), Splay(next, last);  // x's node, if any, is now exactly ch[next][0]
    int del = ch[next][0];
    if(cnt[del] > 1) cnt[del]--, Splay(del);
    else ch[next][0] = 0, push_up(next), push_up(last);  // keep sz exact
}

int n, opt, x;
int main() {
    // Sentinels: every query value then has a predecessor and a successor (Remove,
    // Pre and Succ rely on this), and the INT_MIN node shifts every rank by one.
    Insert(INT_MIN);
    Insert(INT_MAX);
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%d %d", &opt, &x);
        if(opt == 1) Insert(x);
        else if(opt == 2) Remove(x);
        else if(opt == 3) {
            // rank = (#elements < x) + 1. The root's left subtree also holds the INT_MIN
            // sentinel, which supplies the "+1". If x is absent and Find stopped at its
            // predecessor, the root's own copies are < x as well.
            Find(x);
            printf("%d\n", sz[ch[rt][0]] + (val[rt] < x ? cnt[rt] : 0));
        }
        else if(opt == 4) {
            int p = Kth(x + 1);  // +1 skips the INT_MIN sentinel
            Splay(p);            // splay every accessed node to keep the amortized bound
            printf("%d\n", val[p]);
        }
        else if(opt == 5) {
            int p = Pre(x);
            Splay(p);
            printf("%d\n", val[p]);
        }
        else {
            int p = Succ(x);
            Splay(p);
            printf("%d\n", val[p]);
        }
    }
    return 0;
}