class Node:
    """Binary tree node storing a key, two child links and a cached height."""

    def __init__(self, key):
        """Create a leaf node holding *key*."""
        self.key = key
        # A new node has no children yet ...
        self.left = self.right = None
        # ... so its height starts at 1.
        self.height = 1
class AVLTree:
    """Height-balanced binary search tree (AVL tree).

    Keys must support ``<`` and ``>``. Duplicate keys are allowed: a key
    equal to an existing one is inserted into that node's right subtree.

    Attributes:
        root: Root node of the tree, or ``None`` when the tree is empty.
        size: Number of keys currently stored (duplicates counted).
    """

    def __init__(self):
        """Create an empty tree."""
        self.root = None
        self.size = 0

    def __contains__(self, key):
        """Return True if at least one node in the tree holds *key*."""
        node = self.root
        while node is not None:
            if key < node.key:
                node = node.left
            elif key > node.key:
                node = node.right
            else:
                return True
        return False

    def _height(self, node):
        """Return the cached height of *node*, or 0 for an empty subtree."""
        return 0 if node is None else node.height

    def _balance_factor(self, node):
        """Return height(left) - height(right) for *node* (0 if None)."""
        if node is None:
            return 0
        return self._height(node.left) - self._height(node.right)

    def _update_height(self, node):
        """Recompute *node*'s cached height from its children (no-op on None)."""
        if node is not None:
            node.height = 1 + max(
                self._height(node.left), self._height(node.right)
            )

    def _rotate_right(self, y):
        """Rotate the subtree rooted at *y* to the right.

        *y* must have a left child; that child becomes the new subtree root
        and is returned.
        """
        x = y.left
        moved = x.right
        x.right = y
        y.left = moved
        # y is now below x, so its height must be refreshed first.
        self._update_height(y)
        self._update_height(x)
        return x

    def _rotate_left(self, x):
        """Rotate the subtree rooted at *x* to the left.

        *x* must have a right child; that child becomes the new subtree root
        and is returned.
        """
        y = x.right
        moved = y.left
        y.left = x
        x.right = moved
        # x is now below y, so its height must be refreshed first.
        self._update_height(x)
        self._update_height(y)
        return y

    def _rebalance(self, node):
        """Refresh *node*'s height and rotate if it is out of balance.

        The rotation is chosen from the heavy child's balance factor rather
        than by comparing keys, so it stays correct when the tree holds
        equal keys.

        Returns:
            The root of the (possibly rotated) subtree.
        """
        self._update_height(node)
        balance = self._balance_factor(node)
        if balance > 1:
            # Left-heavy; a right-leaning left child needs a double rotation.
            if self._balance_factor(node.left) < 0:
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        if balance < -1:
            # Right-heavy; a left-leaning right child needs a double rotation.
            if self._balance_factor(node.right) > 0:
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        return node

    def insert(self, key):
        """Insert *key* into the tree and increment ``size``."""
        self.root = self._insert_recursive(self.root, key)
        self.size += 1

    def _insert_recursive(self, node, key):
        """Insert *key* below *node* and return the rebalanced subtree root."""
        if node is None:
            return Node(key)
        if key < node.key:
            node.left = self._insert_recursive(node.left, key)
        else:
            # Keys equal to node.key go to the right subtree.
            node.right = self._insert_recursive(node.right, key)
        return self._rebalance(node)

    def delete(self, key):
        """Remove one node holding *key*, if present.

        A missing key is a silent no-op; ``size`` is only decremented when a
        node is actually removed.
        """
        if key not in self:
            return
        self.root = self._delete_recursive(self.root, key)
        self.size -= 1

    def _delete_recursive(self, node, key):
        """Delete one occurrence of *key* below *node*.

        Returns:
            The rebalanced subtree root, or ``None`` if the subtree is empty.
        """
        if node is None:
            return None
        if key < node.key:
            node.left = self._delete_recursive(node.left, key)
        elif key > node.key:
            node.right = self._delete_recursive(node.right, key)
        else:
            # Zero or one child: splice the node out directly.
            if node.left is None:
                return node.right
            if node.right is None:
                return node.left
            # Two children: take over the in-order successor's key, then
            # delete that successor from the right subtree.
            successor = self._find_min(node.right)
            node.key = successor.key
            node.right = self._delete_recursive(node.right, successor.key)
        return self._rebalance(node)

    def _find_min(self, node):
        """Return the leftmost node of the subtree rooted at *node*.

        *node* must not be ``None``.
        """
        while node.left is not None:
            node = node.left
        return node