Deleting a node in a BST (Python)
Solution 1:
node structure and insertion
We start with a simple `node` structure. Notice that the `left` and `right` properties can be set at construction time:
# btree.py
class node:
    """A binary search tree node.

    All three fields can be supplied at construction time, which lets the
    tree functions build new nodes instead of mutating existing ones.
    """

    def __init__(self, data, left=None, right=None):
        # The value stored here, then the (possibly empty) subtrees.
        self.data, self.left, self.right = data, left, right
Recursion comes from the functional tradition, so it yields the best results when used in a functional style. This means avoiding mutation, variable reassignment, and other side effects. Notice how `add` always constructs a new node rather than mutating an existing one. This is why we designed `node` to accept all of its properties at construction time:
# btree.py (continued)
def add(t, q):
    """Return a new tree equal to *t* with *q* inserted.

    *t* is never mutated: nodes along the search path are rebuilt and the
    untouched subtrees are shared with the result.
    """
    # Empty spot found: the new value becomes a leaf.
    if not t:
        return node(q)
    if q < t.data:
        return node(t.data, add(t.left, q), t.right)
    if q > t.data:
        return node(t.data, t.left, add(t.right, q))
    # Equal key: store q in place of the old value, keeping both subtrees.
    return node(q, t.left, t.right)
inorder traversal and string conversion
After we `add` some nodes, we need a way to visualize the tree. Below we write an `inorder` traversal and a `to_str` function:
# btree.py (continued)
def inorder(t):
    """Yield the values of *t* in-order: left subtree, node, right subtree.

    For a binary search tree this produces the values in ascending order.
    """
    if t:
        yield from inorder(t.left)
        yield t.data
        yield from inorder(t.right)
def to_str(t):
    """Render the in-order values of *t* as strings joined by ``"->"``."""
    return "->".join(str(value) for value in inorder(t))
btree object interface
Notice we did not over-complicate the plain functions above by entangling them with a class. We can now define an object-oriented `btree` interface that simply wraps the plain functions:
# btree.py (continued)
class btree:
    """Object-oriented wrapper over the plain tree functions.

    Operations that change the tree return a *new* btree, so calls can be
    chained, e.g. ``btree().add(1).add(2)``.
    """

    def __init__(self, t=None):
        self.t = t  # root node, or None for an empty tree

    def __str__(self):
        return to_str(self.t)

    def add(self, q):
        # Inside a method, the bare name ``add`` is the module-level function.
        return btree(add(self.t, q))

    def inorder(self):
        return inorder(self.t)
Notice we also wrote `btree.py` as its own module. This creates an abstraction barrier that lets us expand, modify, and reuse its features without tangling them with other areas of the program. Let's see how our tree works so far:
# main.py
from btree import btree

# Each .add() returns a new btree, so insertions chain naturally.
t = (
    btree()
    .add(50)
    .add(60)
    .add(40)
    .add(30)
    .add(45)
    .add(55)
    .add(100)
)
print(t)  # print() calls str(t), i.e. btree.__str__
# 30->40->45->50->55->60->100
minimum and maximum
We'll continue working like this, defining plain functions that work directly on `node` objects. Next up, `minimum` and `maximum`:
# btree.py (continued)
from math import inf
def minimum(t, r=inf):
    """Return the smallest value in the BST *t*, or *r* if that is smaller.

    With the default ``r=inf`` this is simply the tree's minimum, and an
    empty tree yields ``inf``.

    Relies on the BST ordering maintained by ``add``/``remove``: the minimum
    sits at the end of the left spine. Only that path is followed
    (O(height)) instead of visiting every node as a full traversal would.
    """
    if not t:
        return r
    return minimum(t.left, min(r, t.data))
def maximum(t, r=-inf):
    """Return the largest value in the BST *t*, or *r* if that is larger.

    With the default ``r=-inf`` this is simply the tree's maximum, and an
    empty tree yields ``-inf``.

    Relies on the BST ordering maintained by ``add``/``remove``: the maximum
    sits at the end of the right spine. Only that path is followed
    (O(height)) instead of visiting every node as a full traversal would.
    """
    if not t:
        return r
    return maximum(t.right, max(r, t.data))
The `btree` interface only provides wrappers around our plain functions:
# btree.py (continued)
# Abbreviated listing: the methods written as `def name(): # ...` are the
# ones already defined earlier and are unchanged. Only the new methods are
# spelled out. (Not runnable on its own; see the completed module for the
# full class.)
class btree:
    def __init__(): # ...
    def __str__(): # ...
    def add(): # ...
    def inorder(): # ...
    # New: thin wrappers that delegate to the module-level functions.
    def maximum(self): return maximum(self.t)
    def minimum(self): return minimum(self.t)
We can test `minimum` and `maximum` now:
# main.py
from btree import btree

t = (
    btree()
    .add(50)
    .add(60)
    .add(40)
    .add(30)
    .add(45)
    .add(55)
    .add(100)
)
print(t)
# 30->40->45->50->55->60->100
print(t.minimum(), t.maximum())  # <- new
# 30 100
insert from iterable
Chaining `.add().add().add()` is a bit verbose. Providing an `add_iter` function lets us insert any number of values from another iterable. (The matching `btree.add_iter` wrapper method is added to the class in the next section.) We introduce it now because we're about to reuse it in the upcoming `remove` function too:
# btree.py (continued)
def add_iter(t, it):
    """Return *t* with every value from the iterable *it* inserted, in order.

    Each insertion goes through ``add``, so *t* itself is left untouched.
    """
    result = t
    for value in it:
        result = add(result, value)
    return result
# main.py
from btree import btree

# Bulk insertion through the btree.add_iter wrapper method.
t = btree().add_iter([50, 60, 40, 30, 45, 55, 100])  # <- new
print(t)
# 30->40->45->50->55->60->100
print(t.minimum(), t.maximum())
# 30 100
node removal and preorder traversal
We now move on to the `remove` function. It works similarly to the `add` function, comparing `t.data` with the value to remove, `q`. You'll notice we use `add_iter` here to combine the `left` and `right` branches of the node being deleted. We could reuse the `inorder` iterator here, but `preorder` keeps the tree more balanced. Reinserting the right branch in sorted (inorder) order would turn it into a long chain, while preorder reinsertion reproduces its original shape. Balancing is a different topic entirely, so we won't go deeper into it now:
# btree.py (continued)
def remove(t, q):
    """Return a new tree equal to *t* with the value *q* removed.

    *t* is never mutated. Nodes on the search path are rebuilt and the
    other subtrees are shared. If *q* is absent, the result holds the same
    values as *t*.
    """
    if not t:
        return t
    if q < t.data:
        return node(t.data, remove(t.left, q), t.right)
    if q > t.data:
        return node(t.data, t.left, remove(t.right, q))
    # Found q: drop this node and merge its two subtrees.
    if not t.left:
        # Nothing to merge into. Re-inserting the right subtree into an empty
        # tree in preorder would rebuild exactly the same shape, so reuse it.
        return t.right
    # Re-insert the right subtree's values into the left subtree. Preorder
    # reproduces the right subtree's shape; sorted (inorder) insertion would
    # degenerate it into a chain.
    return add_iter(t.left, preorder(t.right))
def preorder(t):
    """Yield the values of *t* root-first: node, then left subtree, then right."""
    if t:
        yield t.data
        for subtree in (t.left, t.right):
            yield from preorder(subtree)
Don't forget to extend the `btree` interface:
# btree.py (continued)
# Abbreviated listing: the methods written as `def name(): # ...` are the
# ones already defined earlier and are unchanged. Only the new methods are
# spelled out. (Not runnable on its own; see the completed module for the
# full class.)
class btree:
    def __init__(): # ...
    def __str__(): # ...
    def add(): # ...
    def inorder(): # ...
    def maximum(): # ...
    def minimum(): # ...
    # New: bulk insert and removal return a new btree, like add() does.
    def add_iter(self, it): return btree(add_iter(self.t, it))
    def remove(self, q): return btree(remove(self.t, q))
    # New: expose the preorder generator of the wrapped tree.
    def preorder(self): return preorder(self.t)
Let's see `remove` in action now:
# main.py
from btree import btree

t = btree().add_iter([50, 60, 40, 30, 45, 55, 100])
print(t)
# 30->40->45->50->55->60->100
print(t.minimum(), t.maximum())
# 30 100

# remove() is persistent too: each call returns a new btree.
t = (
    t.remove(30)
    .remove(50)
    .remove(100)
)  # <- new
print(t)
# 40->45->55->60
print(t.minimum(), t.maximum())
# 40 60
completed btree module
Here's the completed module we built over the course of this answer:
from math import inf
class node:
    """A binary search tree node.

    All three fields can be supplied at construction time, which lets the
    tree functions build new nodes instead of mutating existing ones.
    """

    def __init__(self, data, left=None, right=None):
        # The value stored here, then the (possibly empty) subtrees.
        self.data, self.left, self.right = data, left, right
def add(t, q):
    """Return a new tree equal to *t* with *q* inserted.

    *t* is never mutated: nodes along the search path are rebuilt and the
    untouched subtrees are shared with the result.
    """
    # Empty spot found: the new value becomes a leaf.
    if not t:
        return node(q)
    if q < t.data:
        return node(t.data, add(t.left, q), t.right)
    if q > t.data:
        return node(t.data, t.left, add(t.right, q))
    # Equal key: store q in place of the old value, keeping both subtrees.
    return node(q, t.left, t.right)
def add_iter(t, it):
    """Return *t* with every value from the iterable *it* inserted, in order.

    Each insertion goes through ``add``, so *t* itself is left untouched.
    """
    result = t
    for value in it:
        result = add(result, value)
    return result
def maximum(t, r=-inf):
    """Return the largest value in the BST *t*, or *r* if that is larger.

    With the default ``r=-inf`` this is simply the tree's maximum, and an
    empty tree yields ``-inf``.

    Relies on the BST ordering maintained by ``add``/``remove``: the maximum
    sits at the end of the right spine. Only that path is followed
    (O(height)) instead of visiting every node as a full traversal would.
    """
    if not t:
        return r
    return maximum(t.right, max(r, t.data))
def minimum(t, r=inf):
    """Return the smallest value in the BST *t*, or *r* if that is smaller.

    With the default ``r=inf`` this is simply the tree's minimum, and an
    empty tree yields ``inf``.

    Relies on the BST ordering maintained by ``add``/``remove``: the minimum
    sits at the end of the left spine. Only that path is followed
    (O(height)) instead of visiting every node as a full traversal would.
    """
    if not t:
        return r
    return minimum(t.left, min(r, t.data))
def inorder(t):
    """Yield the values of *t* in-order: left subtree, node, right subtree.

    For a binary search tree this produces the values in ascending order.
    """
    if t:
        yield from inorder(t.left)
        yield t.data
        yield from inorder(t.right)
def preorder(t):
    """Yield the values of *t* root-first: node, then left subtree, then right."""
    if t:
        yield t.data
        for subtree in (t.left, t.right):
            yield from preorder(subtree)
def remove(t, q):
    """Return a new tree equal to *t* with the value *q* removed.

    *t* is never mutated. Nodes on the search path are rebuilt and the
    other subtrees are shared. If *q* is absent, the result holds the same
    values as *t*.
    """
    if not t:
        return t
    if q < t.data:
        return node(t.data, remove(t.left, q), t.right)
    if q > t.data:
        return node(t.data, t.left, remove(t.right, q))
    # Found q: drop this node and merge its two subtrees.
    if not t.left:
        # Nothing to merge into. Re-inserting the right subtree into an empty
        # tree in preorder would rebuild exactly the same shape, so reuse it.
        return t.right
    # Re-insert the right subtree's values into the left subtree. Preorder
    # reproduces the right subtree's shape; sorted (inorder) insertion would
    # degenerate it into a chain.
    return add_iter(t.left, preorder(t.right))
def to_str(t):
    """Render the in-order values of *t* as strings joined by ``"->"``."""
    return "->".join(str(value) for value in inorder(t))
class btree:
    """Object-oriented facade over the module's plain tree functions.

    The wrapped root node is kept in ``self.t``. Operations that change the
    tree return a new ``btree`` instead of modifying this one, so calls can
    be chained.
    """

    def __init__(self, t=None):
        self.t = t  # root node, or None for an empty tree

    def __str__(self):
        return to_str(self.t)

    # Persistent updates: each returns a new btree. Inside these methods the
    # bare names (add, add_iter, remove) are the module-level functions.
    def add(self, q):
        return btree(add(self.t, q))

    def add_iter(self, it):
        return btree(add_iter(self.t, it))

    def remove(self, q):
        return btree(remove(self.t, q))

    # Queries on the wrapped tree.
    def maximum(self):
        return maximum(self.t)

    def minimum(self):
        return minimum(self.t)

    def inorder(self):
        return inorder(self.t)

    def preorder(self):
        return preorder(self.t)
have your cake and eat it too
One understated advantage of the approach above is that our `btree` module has a dual interface. We can use it in the traditional object-oriented way, as demonstrated, or in a more functional style:
# main.py -- the same program using the plain functions instead of btree
from btree import add_iter, remove, to_str, minimum, maximum

t = add_iter(None, [50, 60, 40, 30, 45, 55, 100])
print(to_str(t))
# 30->40->45->50->55->60->100
print(minimum(t), maximum(t))
# 30 100

# Same order as remove(remove(remove(t, 30), 50), 100).
for value in (30, 50, 100):
    t = remove(t, value)
print(to_str(t))
# 40->45->55->60
print(minimum(t), maximum(t))
# 40 60
additional reading
I've written extensively about the techniques used in this answer. Follow the links to see them used in other contexts with additional explanation provided -
-
I want to reverse the stack but i dont know how to use recursion for reversing this… How can i reverse the stack without using Recursion
-
Finding all maze solutions with Python
-
Return middle node of linked list with recursion
-
How do i recursively find a size of subtree based on any given node? (BST)
Solution 2:
Instead of deleting the node with `del`, reassign the parent's child reference and let the garbage collector reclaim the removed node.
In `deleteNode`, return the (possibly new) subtree root instead of deleting the node, and assign the returned value back to the parent's link.
def deleteNode(self, currNode, value):
    """Delete *value* from the subtree rooted at *currNode*.

    Returns the (possibly new) root of that subtree. Callers must assign the
    result back to the parent's link (or the tree's root), which is how the
    removed node becomes unreachable and is garbage-collected. Nodes are
    expected to expose ``data``, ``left`` and ``right``.
    """
    if not currNode:
        return currNode
    # Recurse into one side and re-attach the updated subtree; the current
    # node stays the root of this subtree.
    if value < currNode.data:
        currNode.left = self.deleteNode(currNode.left, value)
        return currNode
    if value > currNode.data:
        currNode.right = self.deleteNode(currNode.right, value)
        return currNode
    # Found the node. With zero or one child, its child takes its place.
    if not currNode.right:
        return currNode.left
    if not currNode.left:
        return currNode.right
    # Two children: copy in the in-order successor (the smallest value in the
    # right subtree), then delete that successor from the right subtree.
    successor = currNode.right
    while successor.left:
        successor = successor.left
    currNode.data = successor.data
    currNode.right = self.deleteNode(currNode.right, successor.data)
    return currNode