Stack with find-min/find-max more efficient than O(n)?

I am interested in creating a Java data structure similar to a stack that supports the following operations as efficiently as possible:

  • Push, which adds a new element atop the stack,
  • Pop, which removes the top element of the stack,
  • Find-Max, which returns (but does not remove) the largest element of the stack, and
  • Find-Min, which returns (but does not remove) the smallest element of the stack.

What would be the fastest implementation of this data structure? How might I go about writing it in Java?


This is a classic data structures question. The intuition is as follows: the maximum and minimum can change only when you push a new value onto the stack or pop an existing value off it. Given this, suppose that at each level of the stack you keep track of the maximum and minimum values at or below that point. Then, when you push a new element, you can compute the new maximum and minimum of the whole stack in O(1) time by comparing the new element to the maximum and minimum cached at the current top. Similarly, when you pop an element, you expose the element one step below the top, which already has the maximum and minimum of the remaining stack stored alongside it.

Visually, suppose that we have a stack and push the values 2, 7, 1, 8, 3, and 9, in that order. We start by pushing 2. Since 2 is now both the largest and the smallest value in the stack, we record this:

 2  (max 2, min 2)

Now, let's push 7. Since 7 is greater than 2 (the current max), we end up with this:

 7  (max 7, min 2)
 2  (max 2, min 2)

Notice that we can now read off the max and min of the whole stack just by looking at its top: 7 is the max and 2 is the min. If we now push 1, we get:

 1  (max 7, min 1)
 7  (max 7, min 2)
 2  (max 2, min 2)

Here, we know that 1 is the new minimum because it is smaller than the min cached at the previous top of the stack (2). As an exercise, make sure you understand why, after pushing 8, 3, and 9, we get this:

 9  (max 9, min 1)
 3  (max 8, min 1)
 8  (max 8, min 1)
 1  (max 7, min 1)
 7  (max 7, min 2)
 2  (max 2, min 2)

Now, if we want to query the max and min, we can do so in O(1) by just reading off the stored max and min atop the stack (9 and 1, respectively).

Now, suppose that we pop off the top element. This returns 9 and leaves the stack as:

 3  (max 8, min 1)
 8  (max 8, min 1)
 1  (max 7, min 1)
 7  (max 7, min 2)
 2  (max 2, min 2)

Notice that the max cached at the new top is 8 (and the min is 1), which is exactly correct for the remaining elements! If we then pushed 0, we'd get this:

 0  (max 8, min 0)
 3  (max 8, min 1)
 8  (max 8, min 1)
 1  (max 7, min 1)
 7  (max 7, min 2)
 2  (max 2, min 2)

And, as you can see, the max and min are computed correctly.
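To check the walkthrough mechanically, here is a small sketch that keeps three parallel lists (value, cached max, cached min) and replays the same operations: push 2, 7, 1, 8, 3, 9, pop once, then push 0. Each new level's max is max(value, max of the level below), and likewise for the min. The class and method names (`WalkthroughDemo`, `push`, `pop`, `render`) are hypothetical, chosen only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

public class WalkthroughDemo {
    // Parallel lists: index i holds the value at level i and the max/min at or below it.
    private final List<Integer> values = new ArrayList<>();
    private final List<Integer> maxes = new ArrayList<>();
    private final List<Integer> mins = new ArrayList<>();

    public void push(int v) {
        boolean empty = values.isEmpty();
        // The new level's cache only needs v and the cache of the level just below.
        maxes.add(empty ? v : Math.max(v, maxes.get(maxes.size() - 1)));
        mins.add(empty ? v : Math.min(v, mins.get(mins.size() - 1)));
        values.add(v);
    }

    public int pop() {
        int last = values.size() - 1;
        maxes.remove(last); // remove(int index), not remove(Object)
        mins.remove(last);
        return values.remove(last);
    }

    // Renders the stack top-first, in the same layout as the diagrams above.
    public String render() {
        StringBuilder sb = new StringBuilder();
        for (int i = values.size() - 1; i >= 0; i--) {
            sb.append(' ').append(values.get(i))
              .append("  (max ").append(maxes.get(i))
              .append(", min ").append(mins.get(i)).append(")\n");
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        WalkthroughDemo s = new WalkthroughDemo();
        for (int v : new int[] {2, 7, 1, 8, 3, 9}) {
            s.push(v);
        }
        System.out.print(s.render());
        System.out.println("popped " + s.pop());
        s.push(0);
        System.out.print(s.render());
    }
}
```

The final `render()` call prints the same six annotated levels as the last diagram, with 0 (max 8, min 0) on top.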

Overall, this leads to a stack implementation with O(1) push, pop, find-max, and find-min (push is amortized O(1) if the stack is backed by a dynamic array), which is asymptotically as good as it gets. I'll leave the implementation as an exercise. :-) However, you may want to implement the stack using one of the standard techniques, such as a dynamic array or linked list of objects, each of which holds the stack element, min, and max. You could do this easily by building on ArrayList or LinkedList. Alternatively, you could use the provided Java Stack class, though it extends the synchronized Vector class, so it carries locking overhead that this application doesn't need.
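For anyone who wants a starting point, here is a minimal sketch of that exercise, assuming `int` elements and an unsynchronized `ArrayList` as the backing store. `MinMaxStack` and its nested `Entry` class are hypothetical names, and throwing `NoSuchElementException` on an empty stack is an assumed design choice.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;

public class MinMaxStack {
    // One entry per level: the element plus the max/min at or below it.
    private static final class Entry {
        final int value;
        final int max;
        final int min;

        Entry(int value, int max, int min) {
            this.value = value;
            this.max = max;
            this.min = min;
        }
    }

    private final List<Entry> entries = new ArrayList<>();

    // Amortized O(1) (the ArrayList may resize): compare with the cached max/min below.
    public void push(int value) {
        if (entries.isEmpty()) {
            entries.add(new Entry(value, value, value));
        } else {
            Entry top = top();
            entries.add(new Entry(value, Math.max(value, top.max), Math.min(value, top.min)));
        }
    }

    // O(1): removing the top exposes the entry below, whose cache is already correct.
    public int pop() {
        Entry top = top();
        entries.remove(entries.size() - 1);
        return top.value;
    }

    public int findMax() {
        return top().max;
    }

    public int findMin() {
        return top().min;
    }

    public boolean isEmpty() {
        return entries.isEmpty();
    }

    private Entry top() {
        if (entries.isEmpty()) {
            throw new NoSuchElementException("stack is empty");
        }
        return entries.get(entries.size() - 1);
    }
}
```

For example, after pushing 2, 7, 1, 8, 3, 9 and popping once, `findMax()` returns 8 and `findMin()` returns 1, matching the diagrams above.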

Interestingly, once you've built a stack with these properties, you can use it as a building block to construct a queue that supports the same queries, with each operation running in amortized O(1) time. You can also use it in a more complex construction to build a double-ended queue with these properties as well.
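One common way to realize the queue construction (a sketch, not necessarily the version the answer's author had in mind) is the classic two-stack queue: enqueue onto an "in" stack, dequeue from an "out" stack, and refill "out" from "in" only when "out" is empty. Because each stack caches its own max and min, the queue's max/min is the larger/smaller of the two stacks' cached values. Each element moves between the stacks at most once, which is why the bound is amortized rather than worst-case. `MinMaxQueue` and its method names are hypothetical, and elements are assumed to be `int`.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.NoSuchElementException;

public class MinMaxQueue {
    // Each stack entry is {value, max at or below it, min at or below it}.
    private final Deque<int[]> in = new ArrayDeque<>();
    private final Deque<int[]> out = new ArrayDeque<>();

    // Push v onto a min/max stack, caching that stack's max/min so far.
    private static void pushWithCache(Deque<int[]> stack, int v) {
        int[] top = stack.peek();
        int max = (top == null) ? v : Math.max(v, top[1]);
        int min = (top == null) ? v : Math.min(v, top[2]);
        stack.push(new int[] {v, max, min});
    }

    public void enqueue(int v) {
        pushWithCache(in, v);
    }

    public int dequeue() {
        if (out.isEmpty()) {
            // Move everything from 'in' to 'out', reversing the order so the oldest
            // element ends up on top; caches are recomputed for the 'out' stack.
            while (!in.isEmpty()) {
                pushWithCache(out, in.pop()[0]);
            }
        }
        if (out.isEmpty()) {
            throw new NoSuchElementException("queue is empty");
        }
        return out.pop()[0];
    }

    // The queue's max is the larger of the two stacks' cached maxima.
    public int findMax() {
        requireNonEmpty();
        if (in.isEmpty()) return out.peek()[1];
        if (out.isEmpty()) return in.peek()[1];
        return Math.max(in.peek()[1], out.peek()[1]);
    }

    // The queue's min is the smaller of the two stacks' cached minima.
    public int findMin() {
        requireNonEmpty();
        if (in.isEmpty()) return out.peek()[2];
        if (out.isEmpty()) return in.peek()[2];
        return Math.min(in.peek()[2], out.peek()[2]);
    }

    private void requireNonEmpty() {
        if (in.isEmpty() && out.isEmpty()) {
            throw new NoSuchElementException("queue is empty");
        }
    }
}
```

Note that only the raw value (`in.pop()[0]`) is carried over during a refill: the caches stored in "in" describe a different set of elements than "out" will hold, so they have to be recomputed.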

Hope this helps!

EDIT: If you're curious, I have C++ implementations of a min-stack and the aforementioned min-queue on my personal site. Hopefully these show what this might look like in practice!


Although the answer above is correct, we can do better. If the stack has many elements, storing a min and a max alongside every element wastes a lot of space. We can avoid most of that overhead as follows:

Instead of saving the min (or max) value with each element, we can use two extra stacks. Because the minimum (or maximum) typically changes infrequently, we push a value onto the min (or max) stack only when it is <= the current min (or >= the current max).

Here is the implementation in Java:

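import java.util.Stack;

public class StackWithMinMax extends Stack<Integer> {

    // Auxiliary stacks: only values that became (or tied) the min/max are kept.
    private final Stack<Integer> minStack = new Stack<Integer>();
    private final Stack<Integer> maxStack = new Stack<Integer>();

    // Overrides Stack.push(Integer) (rather than overloading it with push(int)),
    // so pushes made through a Stack<Integer> reference are tracked too.
    @Override
    public Integer push(Integer value) {
        // Note the '=' signs: duplicates of the current min/max must be recorded,
        // or popping one copy would lose the min/max while another copy remains.
        if (value <= min()) {
            minStack.push(value);
        }

        if (value >= max()) {
            maxStack.push(value);
        }

        return super.push(value);
    }

    @Override
    public Integer pop() {
        int value = super.pop();

        if (value == min()) {
            minStack.pop();
        }

        if (value == max()) {
            maxStack.pop();
        }

        return value;
    }

    // Returns Integer.MAX_VALUE if the stack is empty.
    public int min() {
        return minStack.isEmpty() ? Integer.MAX_VALUE : minStack.peek();
    }

    // Returns Integer.MIN_VALUE if the stack is empty.
    public int max() {
        return maxStack.isEmpty() ? Integer.MIN_VALUE : maxStack.peek();
    }
}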
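One caveat with extending `Stack<Integer>`: inherited methods such as `add`, `remove`, or `clear` bypass the min/max bookkeeping. A sketch of the same two-auxiliary-stack idea using composition with `java.util.ArrayDeque` instead (the `Stack` Javadoc itself recommends `Deque` implementations over the legacy class) might look like this. `MinMaxTrackingStack` is a hypothetical name, and throwing on an empty stack instead of returning a sentinel is an assumed design choice.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class MinMaxTrackingStack {
    private final Deque<Integer> items = new ArrayDeque<>();
    // Only values that became (or tied) the current min/max are recorded here.
    private final Deque<Integer> minStack = new ArrayDeque<>();
    private final Deque<Integer> maxStack = new ArrayDeque<>();

    public void push(int value) {
        // '<=' / '>=' so that duplicates of the current min/max are recorded too;
        // otherwise popping one copy would discard a min/max that is still present.
        if (minStack.isEmpty() || value <= minStack.peek()) {
            minStack.push(value);
        }
        if (maxStack.isEmpty() || value >= maxStack.peek()) {
            maxStack.push(value);
        }
        items.push(value);
    }

    // Throws NoSuchElementException if the stack is empty (ArrayDeque.pop behavior).
    public int pop() {
        int value = items.pop();
        // Whenever items is non-empty, both auxiliary stacks are non-empty too.
        if (value == minStack.peek()) {
            minStack.pop();
        }
        if (value == maxStack.peek()) {
            maxStack.pop();
        }
        return value;
    }

    // getFirst() throws NoSuchElementException when empty, instead of returning a sentinel.
    public int min() {
        return minStack.getFirst();
    }

    public int max() {
        return maxStack.getFirst();
    }
}
```

For example, after pushing 3, 1, 1 and popping once, `min()` is still 1, because the second 1 was also recorded thanks to the `<=` comparison.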

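Note that with this approach, minStack and maxStack typically hold far fewer elements than the main stack, saving space. (In the worst case, such as a strictly decreasing sequence of pushes, minStack grows as large as the main stack.) For example: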

Stack      MinStack   MaxStack     (each column lists its stack bottom to top)
7          7          7
4          4          7
5          1          8 (TOP)
6          1 (TOP)
7
8
1
1
7
2
4
2 (TOP)
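To double-check the table, this sketch replays the same twelve pushes (bottom to top) with the `<=`/`>=` rule and reports what ends up in each auxiliary stack. It uses plain lists rather than the class above so that it is self-contained; `AuxStackTrace` and `auxStack` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

public class AuxStackTrace {
    // Returns the auxiliary stack (bottom to top) built by pushing 'values' in order.
    // trackMin = true records values <= current min; false records values >= current max.
    public static List<Integer> auxStack(int[] values, boolean trackMin) {
        List<Integer> aux = new ArrayList<>();
        for (int v : values) {
            if (aux.isEmpty()) {
                aux.add(v);
            } else {
                int top = aux.get(aux.size() - 1);
                if (trackMin ? v <= top : v >= top) {
                    aux.add(v);
                }
            }
        }
        return aux;
    }

    public static void main(String[] args) {
        int[] pushes = {7, 4, 5, 6, 7, 8, 1, 1, 7, 2, 4, 2};
        System.out.println("MinStack (bottom to top): " + auxStack(pushes, true));
        System.out.println("MaxStack (bottom to top): " + auxStack(pushes, false));
    }
}
```

It prints [7, 4, 1, 1] and [7, 7, 8], matching the MinStack and MaxStack columns: 12 elements in the main stack, but only 4 and 3 in the auxiliary ones.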

This may be a late reply, but for the record, here is the Java code:

import java.util.ArrayList;
import java.util.List;

public class MinStack {
    // Each Node holds its value plus the min/max of all nodes at or below it.
    private final List<Node> items = new ArrayList<Node>();

    public void push(int num) {
        Node node = new Node(num);
        if (items.isEmpty()) {
            node.min = num;
            node.max = num;
        } else {
            Node last = items.get(items.size() - 1);
            node.min = Math.min(last.min, num);
            node.max = Math.max(last.max, num);
        }
        items.add(node);
    }

    // Returns the removed node, or null if the stack is empty.
    public Node pop() {
        if (items.isEmpty()) {
            return null;
        }
        return items.remove(items.size() - 1);
    }

    // Throws rather than returning -1, since -1 could be a real element.
    public int getMin() {
        return top().min;
    }

    public int getMax() {
        return top().max;
    }

    private Node top() {
        if (items.isEmpty()) {
            throw new IllegalStateException("stack is empty");
        }
        return items.get(items.size() - 1);
    }

    // Prints the elements bottom to top, then the min/max cached on the top node.
    public void printStack() {
        for (Node n : items) {
            System.out.print(n.data + " > ");
        }
        if (!items.isEmpty()) {
            Node last = top();
            System.out.print(" | Min = " + last.min + " |");
            System.out.print(" | Max = " + last.max + " |");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        MinStack stack = new MinStack();
        stack.push(10);
        stack.push(13);
        stack.push(19);
        stack.push(3);
        stack.push(2);
        stack.push(2);
        stack.printStack();
        stack.pop();
        stack.printStack();
        System.out.println("Min = " + stack.getMin() + ", Max = " + stack.getMax());
    }
}

Node class:

class Node {

    int data;
    int min;
    int max;

    public Node(int data) {
        this.data = data;
    }
}