Find the kth smallest element in a binary search tree in an optimal way

Solution 1:

Here's just an outline of the idea:

In a BST, the left subtree of node T contains only elements smaller than the value stored in T. If k is at most the number of elements in the left subtree, the kth smallest element must belong to the left subtree. If k is exactly one more than that number, the answer is T itself. Otherwise, the kth smallest element is in the right subtree.

We can augment the BST so that each node stores the number of elements in its left subtree (for the steps below, assume that the left subtree of a given node includes that node). With this piece of information, it is simple to traverse the tree by repeatedly asking for the number of elements in the left subtree, to decide whether to recurse into the left or the right subtree.

Now, suppose we are at node T:

  1. If k == num_elements(left subtree of T), then the answer we're looking for is the value in node T.
  2. If k > num_elements(left subtree of T), then we can ignore the left subtree, because all of its elements are smaller than the kth smallest. So, we reduce the problem to finding the (k - num_elements(left subtree of T))th smallest element of the right subtree.
  3. If k < num_elements(left subtree of T), then the kth smallest is somewhere in the left subtree, so we reduce the problem to finding the kth smallest element in the left subtree.
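
The three cases above can be sketched in Java as follows. This is only an illustration under assumptions: the `CountedNode` class and its `leftCount` field are hypothetical names, and `leftCount` follows the convention above of counting the node itself together with its left subtree. The count is computed once at construction here; a real tree would have to update it on insertion and deletion.

```java
// Hypothetical node type: leftCount = nodes in the left subtree + the node itself.
final class CountedNode {
    final int value;
    final int leftCount;
    final CountedNode left, right;

    CountedNode(int value, CountedNode left, CountedNode right) {
        this.value = value;
        this.left = left;
        this.right = right;
        this.leftCount = size(left) + 1;
    }

    // Total number of nodes under n: sum the stored counts along the right spine.
    static int size(CountedNode n) {
        int total = 0;
        while (n != null) {
            total += n.leftCount;
            n = n.right;
        }
        return total;
    }

    // Returns the kth smallest value (1-based), or null if k is out of range.
    static Integer kthSmallest(CountedNode t, int k) {
        while (t != null && k >= 1) {
            if (k == t.leftCount) {
                return t.value;          // case 1: T is the answer
            }
            if (k > t.leftCount) {
                k -= t.leftCount;        // case 2: skip T and its left subtree
                t = t.right;
            } else {
                t = t.left;              // case 3: answer is in the left subtree
            }
        }
        return null;
    }
}
```

Each loop iteration moves one level down, so the lookup visits at most one node per level of the tree.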

Complexity analysis:

This takes O(depth of node) time, which is O(log n) in the worst case on a balanced BST, or O(log n) on average for a random BST. On a degenerate (completely unbalanced) tree the depth can reach n.

A BST requires O(n) storage, and it takes another O(n) to store the information about the number of elements. All BST operations take O(depth of node) time, and it takes O(depth of node) extra time to maintain the "number of elements" information for insertion, deletion or rotation of nodes. Therefore, storing the number of elements in the left subtree does not change the asymptotic space or time complexity of the BST.
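
To make the maintenance cost concrete, here is a minimal Java sketch of insertion into a plain (unbalanced) BST that keeps the count up to date. The names `CountingBst`, `Node` and `leftCount` are made up for this example, duplicates are arbitrarily sent to the right, and deletion and rotations are left out.

```java
final class CountingBst {
    // Hypothetical node layout; leftCount counts the node itself plus its left subtree.
    static final class Node {
        final int value;
        int leftCount = 1;
        Node left, right;

        Node(int value) {
            this.value = value;
        }
    }

    Node root;

    // Inserts value and bumps leftCount on every node whose left subtree receives it.
    void insert(int value) {
        Node fresh = new Node(value);
        if (root == null) {
            root = fresh;
            return;
        }
        Node cur = root;
        while (true) {
            if (value < cur.value) {
                cur.leftCount++;             // the new node lands in cur's left subtree
                if (cur.left == null) {
                    cur.left = fresh;
                    return;
                }
                cur = cur.left;
            } else {
                if (cur.right == null) {
                    cur.right = fresh;
                    return;
                }
                cur = cur.right;
            }
        }
    }

    // Returns the kth smallest value (1-based), or null if k is out of range.
    Integer kthSmallest(int k) {
        Node cur = root;
        while (cur != null && k >= 1) {
            if (k == cur.leftCount) {
                return cur.value;
            }
            if (k > cur.leftCount) {
                k -= cur.leftCount;
                cur = cur.right;
            } else {
                cur = cur.left;
            }
        }
        return null;
    }
}
```

The only extra work during insertion is one increment per left turn, which is where the O(depth of node) maintenance cost comes from.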

Solution 2:

A simpler solution would be to do an inorder traversal and count the elements as they are visited instead of printing them. When the count reaches k, print that element and skip the rest of the tree traversal.

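#include <iostream>

// Assumes Node has `left`, `right` and `data` members.
// *k counts down as nodes are visited in order; call with *k >= 1.
void findK(Node* p, int* k) {
  if (!p || *k <= 0) return;   // empty subtree, or answer already printed
  findK(p->left, k);
  if (*k <= 0) return;         // answer was found in the left subtree
  if (--*k == 0) {
    std::cout << p->data << '\n';
    return;
  }
  findK(p->right, k);
}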

Solution 3:

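public int ReturnKthSmallestElement1(int k)
{
    Node node = Root;
    int count = k;  // rank still to be found within the current subtree (1-based)

    while (node != null)
    {
        // Here SizeOfLeftSubtree() does not count the node itself, hence the + 1 below.
        int sizeOfLeftSubtree = node.SizeOfLeftSubtree();

        if (sizeOfLeftSubtree + 1 == count)
            return node.Value;

        if (sizeOfLeftSubtree < count)
        {
            // Skip the left subtree and this node, then continue on the right.
            count -= sizeOfLeftSubtree + 1;
            node = node.Right;
        }
        else
        {
            node = node.Left;
        }
    }

    return -1;  // k is out of range
}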

This is my implementation in C#, based on the algorithm above. I thought I'd post it so people can understand it better; it works for me.

Thank you, IVlad.

Solution 4:

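// A Java version without recursion: an iterative inorder traversal with an explicit stack.

import java.util.Stack;

public static <T> void find(TreeNode<T> node, int num) {
    Stack<TreeNode<T>> stack = new Stack<TreeNode<T>>();

    TreeNode<T> current = node;
    int tmp = num;  // nodes still to visit before reaching the answer

    while (!stack.isEmpty() || current != null) {
        if (current != null) {
            // Go as far left as possible, remembering the path.
            stack.push(current);
            current = current.getLeft();
        } else {
            // Visit the next node in sorted order.
            current = stack.pop();
            tmp--;

            if (tmp == 0) {
                System.out.println(current.getValue());
                return;
            }

            current = current.getRight();
        }
    }
}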

Solution 5:

A simpler solution would be to do an inorder traversal and keep track of the element currently to be printed with a counter k. When the counter reaches zero, print the element. The runtime is O(n). Note that the function's return type cannot be void: since k is passed by value, each recursive call has to return its updated value of k to the caller. A better solution would be an augmented BST that stores a sorted position value at each node.

public static int kthSmallest (Node pivot, int k){
    if(pivot == null )
        return k;   
    k = kthSmallest(pivot.left, k);
    k--;
    if(k == 0){
        System.out.println(pivot.value);
    }
    k = kthSmallest(pivot.right, k);
    return k;
}
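
As a variation on the code above, the traversal can return the value instead of printing it and stop as soon as the answer is found. The sketch below is one possible way to do that in Java; `InorderKth`, its `Node` class and the one-element `int[]` counter are assumptions of this example, not part of the answer above.

```java
final class InorderKth {
    // Hypothetical minimal node type for the example.
    static final class Node {
        final int value;
        final Node left, right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Returns the kth smallest value (1-based), or null if k is out of range.
    static Integer kthSmallest(Node root, int k) {
        if (k < 1) {
            return null;
        }
        // A one-element array acts as a mutable counter shared by all calls.
        return visit(root, new int[] {k});
    }

    private static Integer visit(Node node, int[] remaining) {
        if (node == null) {
            return null;
        }
        Integer found = visit(node.left, remaining);
        if (found != null) {
            return found;                // already found on the left: stop here
        }
        if (--remaining[0] == 0) {
            return node.value;           // this node is the kth one visited
        }
        return visit(node.right, remaining);
    }
}
```

Because the recursion unwinds as soon as a non-null result appears, only the first k nodes in sorted order (plus the path down to the smallest one) are visited.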