Skip to content

Commit

Permalink
Added implementation for Splay Trees. (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vanshika266 authored Apr 8, 2020
1 parent 307b13e commit ffc9cf9
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pydatastructs/trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
BinarySearchTree,
BinaryTreeTraversal,
AVLTree,
BinaryIndexedTree
BinaryIndexedTree,
SplayTree
)
__all__.extend(binary_trees.__all__)

Expand Down
137 changes: 136 additions & 1 deletion pydatastructs/trees/binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
OneDimensionalArray, DynamicOneDimensionalArray)
from pydatastructs.linear_data_structures.arrays import ArrayForTrees
from collections import deque as Queue
from copy import deepcopy

__all__ = [
'AVLTree',
'BinaryTree',
'BinarySearchTree',
'BinaryTreeTraversal',
'BinaryIndexedTree'
'BinaryIndexedTree',
'SplayTree'
]

class BinaryTree(object):
Expand Down Expand Up @@ -754,6 +756,139 @@ def delete(self, key, **kwargs):
self._balance_deletion(a, key)
return True

class SplayTree(SelfBalancingBinaryTree):
"""
Represents Splay Trees.
References
==========
.. [1] https://en.wikipedia.org/wiki/Splay_tree
"""
def _zig(self, x, p):
if self.tree[p].left == x:
super(SplayTree, self)._right_rotate(p, x)
else:
super(SplayTree, self)._left_rotate(p, x)

def _zig_zig(self, x, p):
super(SplayTree, self)._right_rotate(self.tree[p].parent, p)
super(SplayTree, self)._right_rotate(p, x)

def _zig_zag(self, p):
super(SplayTree, self)._left_right_rotate(self.tree[p].parent, p)

def _zag_zag(self, x, p):
super(SplayTree, self)._left_rotate(self.tree[p].parent, p)
super(SplayTree, self)._left_rotate(p, x)

def _zag_zig(self, p):
super(SplayTree, self)._right_left_rotate(self.tree[p].parent, p)

def splay(self, x, p):
while self.tree[x].parent is not None:
if self.tree[p].parent is None:
self._zig(x, p)
elif self.tree[p].left == x and \
self.tree[self.tree[p].parent].left == p:
self._zig_zig(x, p)
elif self.tree[p].right == x and \
self.tree[self.tree[p].parent].right == p:
self._zag_zag(x, p)
elif self.tree[p].left == x and \
self.tree[self.tree[p].parent].right == p:
self._zag_zig(p)
else:
self._zig_zag(p)
p = self.tree[x].parent

def insert(self, key, x):
super(SelfBalancingBinaryTree, self).insert(key, x)
e, p = super(SelfBalancingBinaryTree, self).search(key, parent=True)
self.tree[self.size-1].parent = p
self.splay(e, p)

def delete(self, x):
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
status = super(SelfBalancingBinaryTree, self).delete(x)
return status

def join(self, other):
"""
Joins two trees current and other such that all elements of
the current splay tree are smaller than the elements of the other tree.
Parameters
==========
other: SplayTree
SplayTree which needs to be joined with the self tree.
"""
maxm = self.root_idx
while self.tree[maxm].right is not None:
maxm = self.tree[maxm].right
minm = other.root_idx
while other.tree[minm].left is not None:
minm = other.tree[minm].left
if not self.comparator(self.tree[maxm].key,
other.tree[minm].key):
raise ValueError("Elements of %s aren't less "
"than that of %s"%(self, other))
self.splay(maxm, self.tree[maxm].parent)
idx_update = self.tree._size
for node in other.tree:
if node is not None:
node_copy = TreeNode(node.key, node.data)
if node.left is not None:
node_copy.left = node.left + idx_update
if node.right is not None:
node_copy.right = node.right + idx_update
self.tree.append(node_copy)
else:
self.tree.append(node)
self.tree[self.root_idx].right = \
other.root_idx + idx_update

def split(self, x):
"""
Splits current splay tree into two trees such that one tree contains nodes
with key less than or equal to x and the other tree containing
nodes with key greater than x.
Parameters
==========
x: key
Key of the element on the basis of which split is performed.
Returns
=======
other: SplayTree
SplayTree containing elements with key greater than x.
"""
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
other = SplayTree(None, None)
if self.tree[self.root_idx].right is not None:
traverse = BinaryTreeTraversal(self)
elements = traverse.depth_first_search(order='pre_order', node=self.tree[self.root_idx].right)
for i in range(len(elements)):
super(SelfBalancingBinaryTree, other).insert(elements[i].key, elements[i].data)
for j in range(len(elements) - 1, -1, -1):
e, p = super(SelfBalancingBinaryTree, self).search(elements[j].key, parent=True)
self.tree[e] = None
self.tree[self.root_idx].right = None
return other

class BinaryTreeTraversal(object):
"""
Represents the traversals possible in
Expand Down
36 changes: 35 additions & 1 deletion pydatastructs/trees/tests/test_binary_trees.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydatastructs.trees.binary_trees import (
BinarySearchTree, BinaryTreeTraversal, AVLTree,
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree)
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree)
from pydatastructs.utils.raises_util import raises
from pydatastructs.utils.misc_util import TreeNode
from copy import deepcopy
Expand Down Expand Up @@ -348,3 +348,37 @@ def test_issue_234():
tree.insert(4.56, 4.56)
tree._left_rotate(5, 8)
assert tree.tree[tree.tree[8].parent].left == 8

def test_SplayTree():
t = SplayTree(100, 100)
t.insert(50, 50)
t.insert(200, 200)
t.insert(40, 40)
t.insert(30, 30)
t.insert(20, 20)
t.insert(55, 55)

assert str(t) == ("[(None, 100, 100, None), (None, 50, 50, None), "
"(0, 200, 200, None), (None, 40, 40, 1), (5, 30, 30, 3), "
"(None, 20, 20, None), (4, 55, 55, 2)]")
t.delete(40)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), "
"(4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), "
"(None, 55, 55, 2)]")
t.delete(150)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), "
"(5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]")

t1 = SplayTree(1000, 1000)
t1.insert(2000, 2000)
assert str(t1) == ("[(None, 1000, 1000, None), (0, 2000, 2000, None)]")

t.join(t1)
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, 8), (4, 50, 50, None), "
"(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), (None, 1000, 1000, None), "
"(7, 2000, 2000, None), '']")

s = t.split(200)
assert str(s) == ("[(1, 2000, 2000, None), (None, 1000, 1000, None)]")
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, None), (4, 50, 50, None), "
"(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), '', '', '']")

0 comments on commit ffc9cf9

Please sign in to comment.