Common balanced-tree functions
Insert/delete a number
Find the predecessor, the successor, and the k-th smallest number
Count the numbers greater than or less than a given number (this can be used to count inversions)
Determine the rank of a number
I have integrated all of these functions. The code is as follows:
#include <cstdio>
#include <cstdlib>

// 0x3fffffff. Returned by find_kth() when the requested rank does not exist.
const int INF = ~0u >> 2;

// Child shorthands. They expand correctly only where the current node index
// is held in a variable named `x`.
#define l ch[x][0]
#define r ch[x][1]
#define kt (ch[ch[rt][1]][0])

const int maxn = 500010;
int lim;

/*
 * Array-based splay tree storing a multiset of ints.
 *
 * Nodes are indices 1..top; index 0 is the null node (init() zeroes its
 * children, parent, size and count).  Equal keys share one node and are
 * counted in cnt[]; sz[x] is the number of stored values (duplicates
 * included) in the subtree rooted at x.
 *
 * Call init() before first use.  Nodes are never recycled, so at most
 * maxn - 1 distinct insertions of new keys are supported.
 */
struct splaytree {
    int sz[maxn];
    int ch[maxn][2];
    int pre[maxn];
    int rt, top;

    /* Recompute sz[x] from its own count and its children's sizes. */
    void up(int x) { sz[x] = cnt[x] + sz[l] + sz[r]; }

    /*
     * Rotate x above its parent.  f is the direction the parent moves to:
     * f == 1 when x is a left child (right rotation), f == 0 otherwise.
     * Only the old parent's size is refreshed; the caller refreshes x.
     */
    void rotate(int x, int f) {
        int y = pre[x];
        ch[y][!f] = ch[x][f];
        pre[ch[x][f]] = y;  // may write pre[0]; pre[0] is never relied upon
        pre[x] = pre[y];
        if (pre[x]) ch[pre[y]][ch[pre[y]][1] == y] = x;
        ch[x][f] = y;
        pre[y] = x;
        up(y);
    }

    /* Splay x until its parent is goal (goal == 0 makes x the root). */
    void splay(int x, int goal) {
        while (pre[x] != goal) {
            if (pre[pre[x]] == goal) {
                rotate(x, ch[pre[x]][0] == x);
            } else {
                int y = pre[x], z = pre[y];
                int f = (ch[z][0] == y);
                if (ch[y][f] == x) {
                    // zig-zag: x and y are on opposite sides
                    rotate(x, !f);
                    rotate(x, f);
                } else {
                    // zig-zig: rotate the parent first
                    rotate(y, f);
                    rotate(x, f);
                }
            }
        }
        up(x);
        if (goal == 0) rt = x;
    }

    /*
     * Splay the node holding the k-th smallest value (1-based, duplicates
     * counted) below goal.  Does nothing when k is out of range.
     */
    void RTO(int k, int goal) {
        int x = rt;
        while (x) {
            if (k <= sz[l]) {
                x = l;
            } else if (k > sz[l] + cnt[x]) {
                k -= sz[l] + cnt[x];
                x = r;
            } else {
                break;
            }
        }
        if (x) splay(x, goal);
    }

    /* Pre-order dump of the subtree rooted at x (debugging aid). */
    void vist(int x) {
        if (x) {
            printf("Node %2d: left son %2d right son %2d val: %2d sz = %d cnt: %d\n",
                   x, l, r, val[x], sz[x], cnt[x]);
            vist(l);
            vist(r);
        }
    }

    void debug() {
        puts("");
        vist(rt);
        puts("");
    }

    /* Allocate a fresh leaf with value c and parent f; its index is stored in x. */
    void newnode(int &x, int c, int f) {
        x = ++top;
        l = r = 0;
        pre[x] = f;
        sz[x] = 1;
        cnt[x] = 1;
        val[x] = c;
    }

    /* Reset to an empty tree. */
    void init() {
        ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
        rt = top = 0;
        cnt[0] = 0;
    }

    /*
     * Insert key into the subtree whose root link is x (a reference to rt or
     * to a child slot); f is the parent of that link.  The touched node is
     * splayed to the root, which also repairs the sizes along the path.
     */
    void insert(int &x, int key, int f) {
        if (!x) {
            newnode(x, key, f);
            splay(x, 0);
            return;
        }
        if (key == val[x]) {
            cnt[x]++;
            sz[x]++;
            splay(x, 0);
            return;
        } else if (key < val[x]) {
            insert(l, key, x);
        } else {
            insert(r, key, x);
        }
        // After the splay, x refers to whatever now occupies this link
        // (possibly node 0); recomputing its size is harmless.
        up(x);
    }

    /* Remove the root node, joining its two subtrees. */
    void del_root() {
        int t = rt;
        if (ch[rt][1]) {
            // Splay the minimum of the right subtree to the top.  The old
            // root t ends up as its left child and is then bypassed.
            rt = ch[rt][1];
            RTO(1, 0);
            ch[rt][0] = ch[t][0];
            if (ch[rt][0]) pre[ch[rt][0]] = rt;
        } else {
            rt = ch[rt][0];
        }
        pre[rt] = 0;
        up(rt);
    }

    /*
     * Predecessor search: ans is set to the node with the largest value
     * <= key in the subtree of x.  ans is left untouched when none exists,
     * so the caller should initialise it (e.g. to 0).
     */
    void findpre(int x, int key, int &ans) {
        if (!x) return;
        if (val[x] <= key) {
            ans = x;
            findpre(r, key, ans);
        } else {
            findpre(l, key, ans);
        }
    }

    /*
     * Successor search: ans is set to the node with the smallest value
     * >= key in the subtree of x.  ans is left untouched when none exists.
     */
    void findsucc(int x, int key, int &ans) {
        if (!x) return;
        if (val[x] >= key) {
            ans = x;
            findsucc(l, key, ans);
        } else {
            findsucc(r, key, ans);
        }
    }

    /*
     * Return the k-th smallest value (1-based) in the subtree of x and splay
     * its node to the root.  Returns INF when k is out of range.
     */
    int find_kth(int x, int k) {
        if (!x) return INF;
        if (k < sz[l] + 1) {
            return find_kth(l, k);
        } else if (k > sz[l] + cnt[x]) {
            return find_kth(r, k - sz[l] - cnt[x]);
        } else {
            splay(x, 0);
            return val[x];
        }
    }

    /* Return the node holding key in the subtree of x, or 0 if absent. */
    int find(int x, int key) {
        if (!x) return 0;
        else if (key < val[x]) return find(l, key);
        else if (key > val[x]) return find(r, key);
        else return x;
    }

    /* Smallest value in the subtree of x (val[0] if x is the null node). */
    int getmin(int x) {
        while (l) x = l;
        return val[x];
    }

    /* Largest value in the subtree of x (val[0] if x is the null node). */
    int getmax(int x) {
        while (r) x = r;
        return val[x];
    }

    /*
     * Rank of key: 1 + the number of stored values smaller than key.
     * cur is the count of smaller values already passed on the way down;
     * pass 0 together with x == rt.  An absent key yields the rank it would
     * receive if inserted.
     */
    int getrank(int x, int key, int cur) {
        if (!x) return cur + 1;
        if (key == val[x]) return sz[l] + cur + 1;
        if (key < val[x]) return getrank(l, key, cur);
        return getrank(r, key, cur + sz[l] + cnt[x]);
    }

    /* Number of stored values strictly less than key in the subtree of x. */
    int get_lt(int x, int key) {
        if (!x) return 0;
        if (val[x] >= key) return get_lt(l, key);
        return cnt[x] + sz[l] + get_lt(r, key);
    }

    /* Number of stored values strictly greater than key in the subtree of x. */
    int get_mt(int x, int key) {
        if (!x) return 0;
        if (val[x] <= key) return get_mt(r, key);
        return cnt[x] + sz[r] + get_mt(l, key);
    }

    /*
     * Drop every node whose value is smaller than the member lim from the
     * subtree whose root link is x; f is the parent of that link.
     */
    void del(int &x, int f) {
        if (!x) return;
        if (val[x] >= lim) {
            del(l, x);
        } else {
            // x and its whole left subtree are < lim: replace x by its
            // right child and keep pruning from there.
            x = r;
            pre[x] = f;
            if (f == 0) rt = x;
            del(x, f);
        }
        if (x) up(x);
    }

    /* Remove all values smaller than lim from the tree. */
    void update() { del(rt, 0); }

    int get_mt(int key) { return get_mt(rt, key); }

    int get_lt(int key) { return get_lt(rt, key); }

    void insert(int key) { insert(rt, key, 0); }

    /* Remove one occurrence of key; does nothing when key is absent. */
    void Delete(int key) {
        int node = find(rt, key);
        if (!node) return;  // splaying node 0 would overwrite rt with 0
        splay(node, 0);
        cnt[rt]--;
        if (!cnt[rt]) del_root();
        else up(rt);  // keep sz[rt] in step with the reduced count
    }

    int kth(int k) { return find_kth(rt, k); }

    int cnt[maxn];
    int val[maxn];
    int lim;
} spt;