From ffc9cf932acb0b7d094de8acbfedd50b1fb5d512 Mon Sep 17 00:00:00 2001 From: Vanshika <48067715+Vanshika266@users.noreply.github.com> Date: Wed, 8 Apr 2020 17:11:39 +0530 Subject: [PATCH] Added implementation for Splay Trees. (#157) --- pydatastructs/trees/__init__.py | 3 +- pydatastructs/trees/binary_trees.py | 137 +++++++++++++++++- .../trees/tests/test_binary_trees.py | 36 ++++- 3 files changed, 173 insertions(+), 3 deletions(-) diff --git a/pydatastructs/trees/__init__.py b/pydatastructs/trees/__init__.py index 42f919c35..fa08fe1d5 100644 --- a/pydatastructs/trees/__init__.py +++ b/pydatastructs/trees/__init__.py @@ -12,7 +12,8 @@ BinarySearchTree, BinaryTreeTraversal, AVLTree, - BinaryIndexedTree + BinaryIndexedTree, + SplayTree ) __all__.extend(binary_trees.__all__) diff --git a/pydatastructs/trees/binary_trees.py b/pydatastructs/trees/binary_trees.py index 6e329cc71..33f368534 100644 --- a/pydatastructs/trees/binary_trees.py +++ b/pydatastructs/trees/binary_trees.py @@ -4,13 +4,15 @@ OneDimensionalArray, DynamicOneDimensionalArray) from pydatastructs.linear_data_structures.arrays import ArrayForTrees from collections import deque as Queue +from copy import deepcopy __all__ = [ 'AVLTree', 'BinaryTree', 'BinarySearchTree', 'BinaryTreeTraversal', - 'BinaryIndexedTree' + 'BinaryIndexedTree', + 'SplayTree' ] class BinaryTree(object): @@ -754,6 +756,139 @@ def delete(self, key, **kwargs): self._balance_deletion(a, key) return True +class SplayTree(SelfBalancingBinaryTree): + """ + Represents Splay Trees. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Splay_tree + + """ + def _zig(self, x, p): + if self.tree[p].left == x: + super(SplayTree, self)._right_rotate(p, x) + else: + super(SplayTree, self)._left_rotate(p, x) + + def _zig_zig(self, x, p): + super(SplayTree, self)._right_rotate(self.tree[p].parent, p) + super(SplayTree, self)._right_rotate(p, x) + + def _zig_zag(self, p): + super(SplayTree, self)._left_right_rotate(self.tree[p].parent, p) + + def _zag_zag(self, x, p): + super(SplayTree, self)._left_rotate(self.tree[p].parent, p) + super(SplayTree, self)._left_rotate(p, x) + + def _zag_zig(self, p): + super(SplayTree, self)._right_left_rotate(self.tree[p].parent, p) + + def splay(self, x, p): + while self.tree[x].parent is not None: + if self.tree[p].parent is None: + self._zig(x, p) + elif self.tree[p].left == x and \ + self.tree[self.tree[p].parent].left == p: + self._zig_zig(x, p) + elif self.tree[p].right == x and \ + self.tree[self.tree[p].parent].right == p: + self._zag_zag(x, p) + elif self.tree[p].left == x and \ + self.tree[self.tree[p].parent].right == p: + self._zag_zig(p) + else: + self._zig_zag(p) + p = self.tree[x].parent + + def insert(self, key, x): + super(SelfBalancingBinaryTree, self).insert(key, x) + e, p = super(SelfBalancingBinaryTree, self).search(key, parent=True) + self.tree[self.size-1].parent = p + self.splay(e, p) + + def delete(self, x): + e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True) + if e is None: + return + self.splay(e, p) + status = super(SelfBalancingBinaryTree, self).delete(x) + return status + + def join(self, other): + """ + Joins two trees current and other such that all elements of + the current splay tree are smaller than the elements of the other tree. + + Parameters + ========== + + other: SplayTree + SplayTree which needs to be joined with the self tree. + + """ + maxm = self.root_idx + while self.tree[maxm].right is not None: + maxm = self.tree[maxm].right + minm = other.root_idx + while other.tree[minm].left is not None: + minm = other.tree[minm].left + if not self.comparator(self.tree[maxm].key, + other.tree[minm].key): + raise ValueError("Elements of %s aren't less " + "than that of %s"%(self, other)) + self.splay(maxm, self.tree[maxm].parent) + idx_update = self.tree._size + for node in other.tree: + if node is not None: + node_copy = TreeNode(node.key, node.data) + if node.left is not None: + node_copy.left = node.left + idx_update + if node.right is not None: + node_copy.right = node.right + idx_update + self.tree.append(node_copy) + else: + self.tree.append(node) + self.tree[self.root_idx].right = \ + other.root_idx + idx_update + + def split(self, x): + """ + Splits current splay tree into two trees such that one tree contains nodes + with key less than or equal to x and the other tree containing + nodes with key greater than x. + + Parameters + ========== + + x: key + Key of the element on the basis of which split is performed. + + Returns + ======= + + other: SplayTree + SplayTree containing elements with key greater than x. + + """ + e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True) + if e is None: + return + self.splay(e, p) + other = SplayTree(None, None) + if self.tree[self.root_idx].right is not None: + traverse = BinaryTreeTraversal(self) + elements = traverse.depth_first_search(order='pre_order', node=self.tree[self.root_idx].right) + for i in range(len(elements)): + super(SelfBalancingBinaryTree, other).insert(elements[i].key, elements[i].data) + for j in range(len(elements) - 1, -1, -1): + e, p = super(SelfBalancingBinaryTree, self).search(elements[j].key, parent=True) + self.tree[e] = None + self.tree[self.root_idx].right = None + return other + class BinaryTreeTraversal(object): """ Represents the traversals possible in diff --git a/pydatastructs/trees/tests/test_binary_trees.py b/pydatastructs/trees/tests/test_binary_trees.py index b516895e4..00be1ac35 100644 --- a/pydatastructs/trees/tests/test_binary_trees.py +++ b/pydatastructs/trees/tests/test_binary_trees.py @@ -1,6 +1,6 @@ from pydatastructs.trees.binary_trees import ( BinarySearchTree, BinaryTreeTraversal, AVLTree, - ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree) + ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree) from pydatastructs.utils.raises_util import raises from pydatastructs.utils.misc_util import TreeNode from copy import deepcopy @@ -348,3 +348,37 @@ def test_issue_234(): tree.insert(4.56, 4.56) tree._left_rotate(5, 8) assert tree.tree[tree.tree[8].parent].left == 8 + +def test_SplayTree(): + t = SplayTree(100, 100) + t.insert(50, 50) + t.insert(200, 200) + t.insert(40, 40) + t.insert(30, 30) + t.insert(20, 20) + t.insert(55, 55) + + assert str(t) == ("[(None, 100, 100, None), (None, 50, 50, None), " + "(0, 200, 200, None), (None, 40, 40, 1), (5, 30, 30, 3), " + "(None, 20, 20, None), (4, 55, 55, 2)]") + t.delete(40) + assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), " + "(4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), " + "(None, 55, 55, 2)]") + t.delete(150) + assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), " + "(5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]") + + t1 = SplayTree(1000, 1000) + t1.insert(2000, 2000) + assert str(t1) == ("[(None, 1000, 1000, None), (0, 2000, 2000, None)]") + + t.join(t1) + assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, 8), (4, 50, 50, None), " + "(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), (None, 1000, 1000, None), " + "(7, 2000, 2000, None), '']") + + s = t.split(200) + assert str(s) == ("[(1, 2000, 2000, None), (None, 1000, 1000, None)]") + assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, None), (4, 50, 50, None), " + "(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), '', '', '']")