Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 62 additions & 239 deletions data_structures/binary_tree/avl_tree.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
"""
Implementation of an auto-balanced binary tree!
For doctests run following command:
python3 -m doctest -v avl_tree.py
For testing run:
python avl_tree.py
"""

from __future__ import annotations

import math

Check failure on line 3 in data_structures/binary_tree/avl_tree.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

data_structures/binary_tree/avl_tree.py:3:8: F401 `math` imported but unused help: Remove unused import: `math`

Check failure on line 3 in data_structures/binary_tree/avl_tree.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

data_structures/binary_tree/avl_tree.py:3:8: F401 `math` imported but unused help: Remove unused import: `math`
import random

Check failure on line 4 in data_structures/binary_tree/avl_tree.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

data_structures/binary_tree/avl_tree.py:4:8: F401 `random` imported but unused help: Remove unused import: `random`

Check failure on line 4 in data_structures/binary_tree/avl_tree.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

data_structures/binary_tree/avl_tree.py:4:8: F401 `random` imported but unused help: Remove unused import: `random`
from typing import Any

Expand All @@ -24,21 +16,13 @@

def push(self, data: Any) -> None:
self.data.append(data)
self.tail = self.tail + 1
self.tail += 1

def pop(self) -> Any:
ret = self.data[self.head]
self.head = self.head + 1
self.head += 1
return ret

def count(self) -> int:
return self.tail - self.head

def print_queue(self) -> None:
print(self.data)
print("**************")
print(self.data[self.head : self.tail])


class MyNode:
def __init__(self, data: Any) -> None:
Expand All @@ -59,9 +43,6 @@
def get_height(self) -> int:
return self.height

def set_data(self, data: Any) -> None:
self.data = data

def set_left(self, node: MyNode | None) -> None:
self.left = node

Expand All @@ -73,277 +54,119 @@


def get_height(node: MyNode | None) -> int:
if node is None:
return 0
return node.get_height()
return node.get_height() if node else 0


def my_max(a: int, b: int) -> int:
if a > b:
return a
return b
return a if a > b else b


def right_rotation(node: MyNode) -> MyNode:
r"""
A B
/ \ / \
B C Bl A
/ \ --> / / \
Bl Br UB Br C
/
UB
UB = unbalanced node
"""
print("left rotation node:", node.get_data())
ret = node.get_left()
assert ret is not None
node.set_left(ret.get_right())
ret.set_right(node)
h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
node.set_height(h1)
h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1
ret.set_height(h2)

node.set_height(
my_max(get_height(node.get_left()), get_height(node.get_right())) + 1
)
ret.set_height(my_max(get_height(ret.get_left()), get_height(ret.get_right())) + 1)

return ret


def left_rotation(node: MyNode) -> MyNode:
"""
a mirror symmetry rotation of the left_rotation
"""
print("right rotation node:", node.get_data())
ret = node.get_right()
assert ret is not None
node.set_right(ret.get_left())
ret.set_left(node)
h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
node.set_height(h1)
h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1
ret.set_height(h2)

node.set_height(
my_max(get_height(node.get_left()), get_height(node.get_right())) + 1
)
ret.set_height(my_max(get_height(ret.get_left()), get_height(ret.get_right())) + 1)

return ret


def lr_rotation(node: MyNode) -> MyNode:
r"""
A A Br
/ \ / \ / \
B C LR Br C RR B A
/ \ --> / \ --> / / \
Bl Br B UB Bl UB C
\ /
UB Bl
RR = right_rotation LR = left_rotation
"""
left_child = node.get_left()
assert left_child is not None
node.set_left(left_rotation(left_child))
node.set_left(left_rotation(node.get_left()))
return right_rotation(node)


def rl_rotation(node: MyNode) -> MyNode:
right_child = node.get_right()
assert right_child is not None
node.set_right(right_rotation(right_child))
node.set_right(right_rotation(node.get_right()))
return left_rotation(node)


def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
def insert_node(node: MyNode | None, data: Any) -> MyNode:
if node is None:
return MyNode(data)

if data < node.get_data():
node.set_left(insert_node(node.get_left(), data))
if (
get_height(node.get_left()) - get_height(node.get_right()) == 2
): # an unbalance detected

if get_height(node.get_left()) - get_height(node.get_right()) == 2:
left_child = node.get_left()
assert left_child is not None
if (
data < left_child.get_data()
): # new node is the left child of the left child
if data < left_child.get_data():
node = right_rotation(node)
else:
node = lr_rotation(node)

else:
node.set_right(insert_node(node.get_right(), data))

if get_height(node.get_right()) - get_height(node.get_left()) == 2:
right_child = node.get_right()
assert right_child is not None
if data < right_child.get_data():
node = rl_rotation(node)
else:
node = left_rotation(node)
h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
node.set_height(h1)
return node


def get_right_most(root: MyNode) -> Any:
while True:
right_child = root.get_right()
if right_child is None:
break
root = right_child
return root.get_data()


def get_left_most(root: MyNode) -> Any:
while True:
left_child = root.get_left()
if left_child is None:
break
root = left_child
return root.get_data()


def del_node(root: MyNode, data: Any) -> MyNode | None:
left_child = root.get_left()
right_child = root.get_right()
if root.get_data() == data:
if left_child is not None and right_child is not None:
temp_data = get_left_most(right_child)
root.set_data(temp_data)
root.set_right(del_node(right_child, temp_data))
elif left_child is not None:
root = left_child
elif right_child is not None:
root = right_child
else:
return None
elif root.get_data() > data:
if left_child is None:
print("No such data")
return root
else:
root.set_left(del_node(left_child, data))
# root.get_data() < data
elif right_child is None:
return root
else:
root.set_right(del_node(right_child, data))

# Re-fetch left_child and right_child references
left_child = root.get_left()
right_child = root.get_right()

if get_height(right_child) - get_height(left_child) == 2:
assert right_child is not None
if get_height(right_child.get_right()) > get_height(right_child.get_left()):
root = left_rotation(root)
else:
root = rl_rotation(root)
elif get_height(right_child) - get_height(left_child) == -2:
assert left_child is not None
if get_height(left_child.get_left()) > get_height(left_child.get_right()):
root = right_rotation(root)
else:
root = lr_rotation(root)
height = my_max(get_height(root.get_right()), get_height(root.get_left())) + 1
root.set_height(height)
return root
node.set_height(
my_max(get_height(node.get_left()), get_height(node.get_right())) + 1
)
return node


class AVLtree:
"""
An AVL tree doctest
Examples:
>>> t = AVLtree()
>>> t.insert(4)
insert:4
>>> print(str(t).replace(" \\n","\\n"))
4
*************************************
>>> t.insert(2)
insert:2
>>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
4
2 *
*************************************
>>> t.insert(3)
insert:3
right rotation node: 2
left rotation node: 4
>>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
3
2 4
*************************************
>>> t.get_height()
2
>>> t.del_node(3)
delete:3
>>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
4
2 *
*************************************
"""

def __init__(self) -> None:
self.root: MyNode | None = None

def insert(self, data: Any) -> None:
self.root = insert_node(self.root, data)

def get_height(self) -> int:
return get_height(self.root)

def insert(self, data: Any) -> None:
print("insert:" + str(data))
self.root = insert_node(self.root, data)

def del_node(self, data: Any) -> None:
print("delete:" + str(data))
if self.root is None:
print("Tree is empty!")
return
self.root = del_node(self.root, data)

def __str__(
self,
) -> str: # a level traversale, gives a more intuitive look on the tree
output = ""
q = MyQueue()
q.push(self.root)
layer = self.get_height()
if layer == 0:
return output
cnt = 0
while not q.is_empty():
node = q.pop()
space = " " * int(math.pow(2, layer - 1))
output += space
if node is None:
output += "*"
q.push(None)
q.push(None)
else:
output += str(node.get_data())
q.push(node.get_left())
q.push(node.get_right())
output += space
cnt = cnt + 1
for i in range(100):
if cnt == math.pow(2, i) - 1:
layer = layer - 1
if layer == 0:
output += "\n*************************************"
return output
output += "\n"
break
output += "\n*************************************"
return output


def _test() -> None:
import doctest

doctest.testmod()


if __name__ == "__main__":
_test()
t = AVLtree()
lst = list(range(10))
random.shuffle(lst)
for i in lst:
t.insert(i)
print(str(t))
random.shuffle(lst)
for i in lst:
t.del_node(i)
print(str(t))
# ✅ NEW FEATURE (YOUR CONTRIBUTION)


def inorder_traversal(root, result):
"""
Performs inorder traversal of AVL tree and stores result.
"""
if root:
inorder_traversal(root.get_left(), result)
result.append(root.get_data())
inorder_traversal(root.get_right(), result)


def avl_sort(arr):
"""
Sorts a list using AVL Tree.

Example:
>>> avl_sort([3,1,2])
[1, 2, 3]
"""
tree = AVLtree()

for value in arr:
tree.insert(value)

result = []
inorder_traversal(tree.root, result)

return result
Loading