给定一棵二叉树,每个节点带有一个数值称为 “关键码”,则“BST 性质”指,对树中任意一个节点:

  1. 该节点的关键码不小于它的左子树中任意节点的关键码。
  2. 该节点的关键码不大于它的右子树中任意节点的关键码。

  满足以上性质的二叉树即为“二叉搜索树(Binary Search Tree,BST)”。二叉搜索树的中序遍历是一个关键码单调递增的节点序列。

Treap

  满足 BST 性质且中序遍历为相同序列得二叉搜索树不唯一,它们是等价的,可以在维持 BST 性质上,改变二叉搜索树形态,使得每个节点左右子树大小达到平衡,从而整棵树深度维持在 \(O(\log{n})\)

  改变形态并保持 BST 性质方法即“旋转”,基本的旋转有“左旋”与“右旋”,如下图所示。

  左右旋代码

1
2
3
4
5
6
7
8
9
10
void zig(int &p) { // 右旋
int q = a[p].l;
a[p].l = a[q].r; a[q].r = p;
p = q;
}
void zag(int &p) { // 左旋
int q = a[p].r;
a[p].r = a[q].l; a[q].l = p;
p = q;
}

  为处理关键码相同的情况,节点增加一个域 cnt,节点的 sz 代表以该节点为根的子树中所有 cnt 的和,不存在重复数值时,sz 就是子树大小。

P3369 【模板】普通平衡树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;

#define debug(x) cout << #x << " is " << x << endl
typedef pair<int, int> pii;
typedef long long ll;
const int INF = 0x3f3f3f3f, N = 1e5 + 5;

struct Treap {
int l, r;
int val, dat;
int cnt, sz;
}a[N];

int tot, rt, n;

int New(int val) {
a[++tot].val = val; a[tot].dat = rand(); a[tot].cnt = a[tot].sz = 1;
return tot;
}

void update(int p) {
a[p].sz = a[a[p].l].sz + a[a[p].r].sz + a[p].cnt;
}

void build() {
New(-INF), New(INF);
rt = 1, a[1].r = 2;
update(rt);
}

int getrank(int p, int val) {
if (!p) return 0;
if (val == a[p].val) return a[a[p].l].sz + 1;
if (val < a[p].val) return getrank(a[p].l, val);
return getrank(a[p].r, val) + a[a[p].l].sz + a[p].cnt;
}

int getval(int p, int rank) {
if (!p) return INF;
if (a[a[p].l].sz >= rank) return getval(a[p].l, rank);
if (a[a[p].l].sz + a[p].cnt >= rank) return a[p].val;
return getval(a[p].r, rank - a[a[p].l].sz - a[p].cnt);
}

void zig(int &p) {
int q = a[p].l;
a[p].l = a[q].r, a[q].r = p, p = q;
update(a[p].r), update(p);
}

void zag(int &p) {
int q = a[p].r;
a[p].r = a[q].l, a[q].l = p, p = q;
update(a[p].l), update(p);
}

void insert(int &p, int val) {
if (!p) { p = New(val); return; }
if (val == a[p].val) {
a[p].cnt++, update(p);
return;
}
if (val < a[p].val) {
insert(a[p].l, val);
if (a[p].dat < a[a[p].l].dat) zig(p); // 不满足堆性质,右旋
}
else {
insert(a[p].r, val);
if (a[p].dat < a[a[p].r].dat) zag(p); // 不满足堆性质,左旋
}
update(p);
}

int getpre(int val) {
int ans = 1, p = rt; // a[1].val = -INF;
while (p) {
if (val == a[p].val) {
if (a[p].l > 0) {
p = a[p].l;
while (a[p].r > 0) p = a[p].r;
ans = p;
}
break;
}
if (a[p].val < val && a[p].val > a[ans].val) ans = p;
p = val < a[p].val ? a[p].l : a[p].r;
}
return a[ans].val;
}

int getnext(int val) {
int ans = 2, p = rt; // a[2].val = INF;
while (p) {
if (val == a[p].val) {
if (a[p].r > 0) {
p = a[p].r;
while (a[p].l > 0) p = a[p].l;
ans = p;
}
break;
}
if (a[p].val > val && a[p].val < a[ans].val) ans = p;
p = val < a[p].val ? a[p].l : a[p].r;
}
return a[ans].val;
}

void remove(int &p, int val) {
if (!p) return;
if (val == a[p].val) {
if (a[p].cnt > 1) {
a[p].cnt--, update(p);
return;
}
if (a[p].l || a[p].r) {
if (!a[p].r || a[a[p].l].dat > a[a[p].r].dat)
zig(p), remove(a[p].r, val);
else
zag(p), remove(a[p].l, val);
update(p);
}
else p = 0;
return;
}
val < a[p].val ? remove(a[p].l, val) : remove(a[p].r, val);
update(p);
}

int main()
{
build();
scanf("%d", &n);
int op, x;
while (n--) {
scanf("%d%d", &op, &x);
if (op == 1) insert(rt, x);
else if (op == 2) remove(rt, x);
else if (op == 3) printf("%d\n", getrank(rt, x) - 1);
else if (op == 4) printf("%d\n", getval(rt, x + 1));
else if (op == 5) printf("%d\n", getpre(x));
else if (op == 6) printf("%d\n", getnext(x));
}
return 0;
}