splay基础

splay简介

splay树也是一种二叉平衡树,同时为了保证始终可以对查询频率较高的节点优先查询,splay树会对于每一个操作的节点,均将该节点旋转到树根。据证明,按照splay的旋转方式,可以保证对于大部分数据均可以达到logn的复杂度

splay树的旋转

我们可以分两种情况

1.

2.

其实是四种情况,以上两种分别对称为另外两种情况,但旋转顺序相同

旋转前后图的中序遍历是不变的,且可以保证复杂度为logn

实现如下

splay树的插入

可以写做splay(x,k),将x点插入到k点之后

我们可以找到k的后继k+1,将x点插入k+1之后。(中序遍历)

我们可以先将k转到根节点,然后将k+1转到k的下面。

因为旋转后中序遍历不变,所以k+1依旧是k的后继,所以k的左子树必然为空,我们将x点放到k的左子树即可。

splay树的删除

假设我们要删除l和r之间的序列,我们可以先找到l-1和r+1,我们将l-1转到根节点,然后将r+1转到l-1的下面,然后将r+1的左子树变为空,深入了解中序遍历,将会很容易理解

维护一个splay树的信息

可以使用类似线段树的懒标记

具体可以看模板例题实现

题目链接

关于splay树的主体

 struct Node {
     int s[2], p, v;//s为子树,p为父亲,v为树所代表的值
     int size, flag;//size为子树数目,flag为懒标记
 
     void init(int vv, int pp) {
         v = vv, p = pp;
         size = 1;
     }
 } tr[maxn];

pushdown与pushup

pushup获取子树数目属于维护信息,pushdown将懒标记下传

pushup放在旋转结束来维护信息

pushdown放在递归寻找遍历之前

 void pushup(int x) {
     tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
 }
 
 void pushdown(int x) {
     if (tr[x].flag) {
         swap(tr[x].s[0], tr[x].s[1]);
         tr[tr[x].s[0]].flag ^= 1;
         tr[tr[x].s[1]].flag ^= 1;
         tr[x].flag = 0;
     }
 }

旋转实现如下

 void rotate(int x) {
     int y = tr[x].p, z = tr[y].p;
     int k = tr[y].s[1] == x;
     tr[z].s[tr[z].s[1] == y] = x;
     tr[x].p = z;
     tr[y].s[k] = tr[x].s[k ^ 1];
     tr[tr[x].s[k ^ 1]].p = y;
     tr[x].s[k ^ 1] = y;
     tr[y].p = x;
     pushup(y);
     pushup(x);
 }
 

递归时

 int get_k(int k) {
     int u = root;
     while (true) {
         pushdown(u);
         if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
         else if (tr[tr[u].s[0]].size + 1 == k) return u;
         else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
     }
 }
 
 void output(int u) {
     pushdown(u);
     if (tr[u].s[0]) output(tr[u].s[0]);
     if (tr[u].v >= 1 && tr[u].v <= n) printf("%d ", tr[u].v);
     if (tr[u].s[1]) output(tr[u].s[1]);
 }

主体代码实现

 void splay(int x, int k) {
     while (tr[x].p != k) {
         int y = tr[x].p, z = tr[y].p;
         if (z != k) {
             if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
             else rotate(y);
         }
         rotate(x);
     }
     if (!k) root = x;
 }
 
 void insert(int v) {
     int u = root, p = 0;
     while (u) p = u, u = tr[u].s[v > tr[u].v];
     u = ++idx;
     if (p) tr[p].s[v > tr[p].v] = u;
     tr[u].init(v, p);
     splay(u, 0);
 }

ac代码:

 #include <cstdio>
 #include <iostream>
 #include <cmath>
 #include <cstdlib>
 #include <algorithm>
 #include <cstring>
 #include <string>
 #include <vector>
 #include <list>
 #include <map>
 #include <unordered_map>
 #include <queue>
 #include <set>
 #include <deque>
 #include <list>
 #include <stack>
 
 #define ll long long
 using namespace std;
 ll mod = 1e9 + 7;
 
 const ll INF = 1e18;
 const ll maxn = 1e5 + 5;
 const double pi = 3.141592653;
 
 ll gcd(ll a, ll b) {
     while (b ^= a ^= b ^= a %= b);
     return a;
 }
 
 int n, m;
 
 struct Node {
     int s[2], p, v;//s为子树,p为父亲,v为树所代表的值
     int size, flag;//size为子树数目,flag为懒标记
 
     void init(int vv, int pp) {
         v = vv, p = pp;
         size = 1;
     }
 } tr[maxn];
 
 int root, idx;
 
 void pushup(int x) {
     tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
 }
 
 void pushdown(int x) {
     if (tr[x].flag) {
         swap(tr[x].s[0], tr[x].s[1]);
         tr[tr[x].s[0]].flag ^= 1;
         tr[tr[x].s[1]].flag ^= 1;
         tr[x].flag = 0;
     }
 }
 
 void rotate(int x) {
     int y = tr[x].p, z = tr[y].p;
     int k = tr[y].s[1] == x;
     tr[z].s[tr[z].s[1] == y] = x;
     tr[x].p = z;
     tr[y].s[k] = tr[x].s[k ^ 1];
     tr[tr[x].s[k ^ 1]].p = y;
     tr[x].s[k ^ 1] = y;
     tr[y].p = x;
     pushup(y);
     pushup(x);
 }
 
 void splay(int x, int k) {
     while (tr[x].p != k) {
         int y = tr[x].p, z = tr[y].p;
         if (z != k) {
             if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
             else rotate(y);
         }
         rotate(x);
     }
     if (!k) root = x;
 }
 
 void insert(int v) {
     int u = root, p = 0;
     while (u) p = u, u = tr[u].s[v > tr[u].v];
     u = ++idx;
     if (p) tr[p].s[v > tr[p].v] = u;
     tr[u].init(v, p);
     splay(u, 0);
 }
 
 int get_k(int k) {
     int u = root;
     while (true) {
         pushdown(u);
         if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
         else if (tr[tr[u].s[0]].size + 1 == k) return u;
         else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
     }
 }
 
 void output(int u) {
     pushdown(u);
     if (tr[u].s[0]) output(tr[u].s[0]);
     if (tr[u].v >= 1 && tr[u].v <= n) printf("%d ", tr[u].v);
     if (tr[u].s[1]) output(tr[u].s[1]);
 }
 
 void solve() {
     scanf("%d%d", &n, &m);
     for (int i = 0; i <= n + 1; i++) insert(i);
     while (m--) {
         int l, r;
         scanf("%d%d", &l, &r);
         l = get_k(l), r = get_k(r + 2);
         splay(l, 0);
         splay(r, l);
         tr[tr[r].s[0]].flag ^= 1;
 
     }
     output(root);
 }
 
 
 int main() {
 #ifdef ONLINE_JUDGE
 #else
     freopen("in.txt", "r", stdin);
     //freopen("out.txt", "w", stdout);
 #endif
     int T = 1;
     //scanf("%d", &T);
     while (T--) solve();
 }

《splay基础》有1条评论

发表评论