Binary Search Tree - Java Implementation

I'm writing a program that utilizes a binary search tree to store data. In a previous program (unrelated), I was able to implement a linked list using an implementation provided with Java SE6. Is there something similar for a binary search tree, or will I need to "start from scratch"?


You can use a TreeMap. TreeMap is implemented as a red-black tree, which is a self-balancing binary search tree.


According to the Collections Framework Overview, Java provides two balanced-tree implementations:

  • TreeSet
  • TreeMap

Here is my simple binary search tree implementation in Java SE 1.8:

/**
 * One node of a binary search tree holding {@code int} values.
 *
 * Besides its payload, every node keeps links to its parent and to its
 * left and right children. A link that is {@code null} means "absent".
 */
public class BSTNode
{
    /** Value stored in this node. */
    int data;
    /** Node this one hangs from, or {@code null} when none has been set. */
    BSTNode parent;
    /** Left child, or {@code null}. */
    BSTNode left;
    /** Right child, or {@code null}. */
    BSTNode right;

    /**
     * Creates an unlinked node; all links stay {@code null} and
     * {@code data} keeps its default value of 0.
     */
    public BSTNode()
    {
    }

    /**
     * Creates an unlinked node carrying the given value.
     *
     * @param data value to store in the node
     */
    public BSTNode(int data)
    {
        // Delegating leaves parent/left/right at their null defaults.
        this();
        this.data = data;
    }
}

/**
 * A minimal, unbalanced binary search tree of {@code int} values built from
 * {@link BSTNode} objects.
 *
 * Values smaller than a node's value are stored in its left subtree; values
 * greater than or equal to it (so duplicates too) are stored in its right
 * subtree. The tree is never rebalanced.
 */
public class BSTFunctions
{
    /** Root of the tree, or {@code null} while the tree is empty. */
    BSTNode ROOT;

    /** Creates an empty tree. */
    public BSTFunctions()
    {
        this.ROOT = null;
    }

    /**
     * Inserts {@code data} into the tree, starting the descent at {@code node}.
     *
     * Callers normally pass {@link #ROOT}. If {@code node} is {@code null} the
     * descent starts at {@code ROOT} instead; a new root is only created when
     * the tree is really empty, so an existing tree is never discarded.
     *
     * The descent is iterative, so a degenerate (list-shaped) tree cannot
     * overflow the call stack.
     *
     * @param node node at which to start looking for the insertion point
     * @param data value to insert; duplicates are placed in the right subtree
     */
    void insertNode(BSTNode node, int data)
    {
        BSTNode created = new BSTNode(data);
        BSTNode current = (node != null) ? node : ROOT;

        if (current == null)
        {
            // Empty tree: the new node becomes the root.
            ROOT = created;
            return;
        }

        while (true)
        {
            if (data < current.data)
            {
                if (current.left == null)
                {
                    current.left = created;
                    created.parent = current;
                    return;
                }
                current = current.left;
            }
            else
            {
                if (current.right == null)
                {
                    current.right = created;
                    created.parent = current;
                    return;
                }
                current = current.right;
            }
        }
    }

    /**
     * Tells whether {@code data} occurs in the subtree rooted at {@code node}.
     *
     * @param node root of the subtree to search; {@code null} is an empty subtree
     * @param data value to look for
     * @return {@code true} if a node holding {@code data} is found
     */
    public boolean search(BSTNode node, int data)
    {
        BSTNode current = node;
        while (current != null)
        {
            if (current.data == data)
            {
                return true;
            }
            current = (data < current.data) ? current.left : current.right;
        }
        return false;
    }

    /**
     * Prints the subtree rooted at {@code node} to standard output in
     * in-order (left, node, right), each value followed by {@code " - "}.
     *
     * @param node root of the subtree to print; {@code null} prints nothing
     */
    public void printInOrder(BSTNode node)
    {
        if (node != null)
        {
            printInOrder(node.left);
            System.out.print(node.data + " - ");
            printInOrder(node.right);
        }
    }

    /**
     * Prints the subtree rooted at {@code node} to standard output in
     * post-order (left, right, node), each value followed by {@code " - "}.
     *
     * @param node root of the subtree to print; {@code null} prints nothing
     */
    public void printPostOrder(BSTNode node)
    {
        if (node != null)
        {
            printPostOrder(node.left);
            printPostOrder(node.right);
            System.out.print(node.data + " - ");
        }
    }

    /**
     * Prints the subtree rooted at {@code node} to standard output in
     * pre-order (node, left, right), each value followed by {@code " - "}.
     *
     * @param node root of the subtree to print; {@code null} prints nothing
     */
    public void printPreOrder(BSTNode node)
    {
        if (node != null)
        {
            System.out.print(node.data + " - ");
            printPreOrder(node.left);
            printPreOrder(node.right);
        }
    }

    /**
     * Small demonstration: builds a seven-element tree, prints it in the
     * three traversal orders and runs one successful and one failing search.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        BSTFunctions f = new BSTFunctions();

        // Insert
        f.insertNode(f.ROOT, 20);
        f.insertNode(f.ROOT, 5);
        f.insertNode(f.ROOT, 25);
        f.insertNode(f.ROOT, 3);
        f.insertNode(f.ROOT, 7);
        f.insertNode(f.ROOT, 27);
        f.insertNode(f.ROOT, 24);

        // Print
        f.printInOrder(f.ROOT);
        System.out.println("");
        f.printPostOrder(f.ROOT);
        System.out.println("");
        f.printPreOrder(f.ROOT);
        System.out.println("");

        // Search
        System.out.println(f.search(f.ROOT, 27) ? "Found" : "Not Found");
        System.out.println(f.search(f.ROOT, 10) ? "Found" : "Not Found");
    }
}

And the output is:

3 - 5 - 7 - 20 - 24 - 25 - 27 - 
3 - 7 - 5 - 24 - 27 - 25 - 20 - 
20 - 5 - 3 - 7 - 25 - 24 - 27 - 
Found
Not Found

Here is a sample implementation:

import java.util.*;

/**
 * An unbalanced binary search tree mapping keys to values.
 *
 * Keys are ordered by the comparator given at construction time or, when
 * that comparator is {@code null}, by their natural ordering (the keys are
 * then cast to {@link Comparable}). Each key occurs at most once; adding an
 * existing key replaces its value. Null keys are rejected by {@link #add}.
 *
 * Iterators walk the entries in ascending key order and fail with
 * {@link ConcurrentModificationException} once an entry has been inserted
 * or removed after the iterator was created.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class MyBSTree<K,V> implements MyTree<K,V>{
    private BSTNode<K,V> _root;
    /** Number of entries currently stored. */
    private int _size;
    private final Comparator<K> _comparator;
    /** Structural modification counter used by iterators to detect changes. */
    private int mod = 0;

    /**
     * Creates an empty tree.
     *
     * @param comparator ordering for the keys, or {@code null} to use the
     *                   keys' natural ordering
     */
    public MyBSTree(Comparator<K> comparator){
        _comparator = comparator;
    }

    /**
     * @return the root node, or {@code null} if the tree is empty
     */
    public Node<K,V> root(){
        return _root;
    }

    /**
     * @return the number of entries stored in the tree
     */
    public int size(){
        return _size;
    }

    /**
     * Tells whether an entry with the given key exists.
     *
     * @param key key to look for
     * @return {@code true} if the key is present
     */
    public boolean containsKey(K key){
        return node(key) != null;
    }

    /**
     * Compares two keys with the configured comparator, falling back to the
     * natural ordering of {@code k1} when no comparator was supplied.
     *
     * @throws ClassCastException if no comparator was supplied and
     *                            {@code k1} is not {@link Comparable}
     */
    private int compare(K k1, K k2){
        if(_comparator != null){
            return _comparator.compare(k1,k2);
        }
        else {
            // Without a comparator the keys are required to be Comparable;
            // a key that is not fails here with ClassCastException.
            @SuppressWarnings("unchecked")
            Comparable<K> comparable = (Comparable<K>)k1;
            return comparable.compareTo(k2);
        }
    }


    /**
     * Looks up the value stored under {@code key}.
     *
     * @param key key to look for
     * @return the associated value, or {@code null} if the key is absent
     *         (or present with a {@code null} value)
     */
    public V get(K key){
        Node<K,V> node = node(key);

        return node != null ? node.value() : null;
    }

    /**
     * Finds the node holding {@code key}.
     *
     * @return the matching node, or {@code null} if there is none
     */
    private BSTNode<K,V> node(K key){
        BSTNode<K,V> node = _root;

        while (node != null){
            int comparison = compare(key, node.key());

            if(comparison == 0){
                return node;
            }else if(comparison < 0){
                node = node._left;
            }else {
                node = node._right;
            }
        }

        return null;
    }

    /**
     * Associates {@code value} with {@code key}. If the key is already
     * present only its value is replaced; this is not a structural change
     * and does not invalidate iterators.
     *
     * @param key   key to insert; must not be {@code null}
     * @param value value to store, may be {@code null}
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void add(K key, V value){
        if(key == null){
            throw new IllegalArgumentException("key");
        }

        BSTNode<K,V> parent = null, curr = _root;
        int comparison = 0;
        while(curr != null){
            comparison = compare(key, curr.key());

            if(comparison == 0){
                curr._value = value;
                return;
            }

            parent = curr;
            curr = comparison < 0 ? curr._left : curr._right;
        }

        BSTNode<K,V> created = new BSTNode<K, V>(key, value);
        // The parent link is required by removeNode to unlink nodes later.
        created._parent = parent;

        if(parent == null){
            _root = created;
        }else if(comparison < 0){
            parent._left = created;
        }else {
            parent._right = created;
        }

        _size++;
        mod++;
    }

    /**
     * Unlinks the entry held by {@code curr} from the tree.
     *
     * A node with two children keeps its place and takes over key and value
     * of its in-order predecessor (the right-most node of its left subtree);
     * the predecessor node, which has at most a left child, is unlinked
     * instead. A node with at most one child is replaced in its parent by
     * that child (or by nothing).
     */
    private void removeNode(BSTNode<K,V> curr){
        if(curr._left != null && curr._right != null){
            BSTNode<K,V> x = curr._left;
            // find right-most node of left sub-tree
            while (x._right != null){
                x = x._right;
            }
            // move that to current
            curr._key = x._key;
            curr._value = x._value;
            // the predecessor node now holds duplicate data: unlink it below
            curr = x;
        }

        // From here on curr has at most one child.
        BSTNode<K,V> child = curr._left != null ? curr._left : curr._right;
        BSTNode<K,V> parent = curr._parent;

        if(child != null){
            child._parent = parent;
        }

        if(parent == null){
            _root = child;
        }else if(parent._left == curr){
            parent._left = child;
        }else {
            parent._right = child;
        }

        // Detach the removed node completely so it no longer references the tree.
        curr._left = null;
        curr._right = null;
        curr._parent = null;
    }


    /**
     * Removes the entry stored under {@code key}, if any.
     *
     * @param key key of the entry to remove
     * @return the value that was stored under the key, or {@code null} if
     *         the key was absent (or stored a {@code null} value)
     */
    public V remove(K key){
        BSTNode<K,V> target = node(key);
        if(target == null){
            return null;
        }

        // Read the value first: removeNode may overwrite the node's contents.
        V val = target._value;
        removeNode(target);
        _size--;
        mod++;

        return val;
    }

    /**
     * @return an iterator over the entries in ascending key order;
     *         it does not support {@code remove()}
     */
    public Iterator<MyTree.Node<K,V>> iterator(){
        return new MyIterator();
    }

    /**
     * In-order iterator. The stack always holds the nodes still to be
     * visited whose left subtrees have already been handled, with the next
     * node to return on top.
     */
    private class MyIterator implements Iterator<Node<K,V>>{
        private final int _startMod;
        private final Deque<BSTNode<K,V>> _stack;

        public MyIterator(){
            _startMod = MyBSTree.this.mod;
            _stack = new ArrayDeque<BSTNode<K, V>>();

            pushLeftSpine(MyBSTree.this._root);
        }

        /** Pushes {@code node} and all of its left descendants. */
        private void pushLeftSpine(BSTNode<K,V> node){
            while (node != null){
                _stack.push(node);
                node = node._left;
            }
        }

        /** Fails if the tree was structurally modified after creation. */
        private void checkForModification(){
            if(MyBSTree.this.mod != _startMod){
                throw new ConcurrentModificationException();
            }
        }

        @Override
        public void remove(){
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean hasNext(){
            checkForModification();

            return !_stack.isEmpty();
        }

        @Override
        public Node<K,V> next(){
            // hasNext() performs the modification check first.
            if(!hasNext()){
                throw new NoSuchElementException();
            }

            BSTNode<K,V> node = _stack.pop();
            // The successor is the left-most node of the right subtree.
            pushLeftSpine(node._right);

            return node;
        }
    }

    /**
     * @return {@code "[]"} for an empty tree, otherwise the nested
     *         {@code [left key right]} rendering of the root node
     */
    @Override
    public String toString(){
        if(_root == null) return "[]";

        return _root.toString();
    }

    /**
     * Tree node holding one key/value pair together with links to its
     * children and its parent.
     */
    private static class BSTNode<K,V> implements Node<K,V>{
        K _key;
        V _value;
        BSTNode<K,V> _left, _right, _parent;

        /**
         * @throws IllegalArgumentException if {@code key} is {@code null}
         */
        public BSTNode(K key, V value){
            if(key == null){
                throw new IllegalArgumentException("key");
            }

            _key = key;
            _value = value;
        }

        public K key(){
            return _key;
        }

        public V value(){
            return _value;
        }

        public Node<K,V> left(){
            return _left;
        }

        public Node<K,V> right(){
            return _right;
        }

        public Node<K,V> parent(){
            return _parent;
        }

        /** @return {@code true} if this node is the left child of its parent */
        boolean isLeft(){
            if(_parent == null) return false;

            return _parent._left == this;
        }

        /** @return {@code true} if this node is the right child of its parent */
        boolean isRight(){
            if(_parent == null) return false;

            return _parent._right == this;
        }

        /**
         * Two nodes are equal when their keys are equal and their values are
         * either both {@code null} or equal; tree links are not compared.
         */
        @Override
        public boolean equals(Object o){
            if(this == o){
                return true;
            }
            if(!(o instanceof BSTNode)){
                return false;
            }

            BSTNode<?,?> node = (BSTNode<?,?>)o;
            if(!node._key.equals(_key)){
                return false;
            }

            return _value == null ? node._value == null : _value.equals(node._value);
        }

        @Override
        public int hashCode(){
            int hashCode = _key.hashCode();

            if(_value != null){
                hashCode ^= _value.hashCode();
            }

            return hashCode;
        }

        /**
         * Renders the subtree rooted here as {@code [left key right]}, where
         * a missing child is rendered as the empty string.
         */
        @Override
        public String toString(){
            StringBuilder sb = new StringBuilder();
            appendTo(sb);
            return sb.toString();
        }

        /** Appends the rendering of this subtree to a shared builder. */
        private void appendTo(StringBuilder sb){
            sb.append('[');
            if(_left != null) _left.appendTo(sb);
            sb.append(' ').append(_key).append(' ');
            if(_right != null) _right.appendTo(sb);
            sb.append(']');
        }
    }
}