The HashSet<T>.removeAll method is surprisingly slow

Jon Skeet recently raised an interesting programming topic on his blog, in a post titled "There's a hole in my abstraction, dear Liza, dear Liza":

I have a set – a HashSet, in fact. I want to remove some items from it… and many of the items may well not exist. In fact, in our test case, none of the items in the "removals" collection will be in the original set. This sounds – and indeed is – extremely easy to code. After all, we’ve got Set<T>.removeAll to help us, right?

We specify the size of the "source" set and the size of the "removals" collection on the command line, and build both of them. The source set contains only non-negative integers; the removals set contains only negative integers. We measure how long it takes to remove all the elements using System.currentTimeMillis(), which isn’t the world’s most accurate stopwatch but is more than adequate in this case, as you’ll see. Here’s the code:

import java.util.*;
public class Test 
{ 
    public static void main(String[] args) 
    { 
       int sourceSize = Integer.parseInt(args[0]); 
       int removalsSize = Integer.parseInt(args[1]); 
        
       Set<Integer> source = new HashSet<Integer>(); 
       Collection<Integer> removals = new ArrayList<Integer>(); 
        
       for (int i = 0; i < sourceSize; i++) 
       { 
           source.add(i); 
       } 
       for (int i = 1; i <= removalsSize; i++) 
       { 
           removals.add(-i); 
       } 
        
       long start = System.currentTimeMillis(); 
       source.removeAll(removals); 
       long end = System.currentTimeMillis(); 
       System.out.println("Time taken: " + (end - start) + "ms"); 
    }
}

Let’s start off by giving it an easy job: a source set of 100 items, and 100 to remove:

c:\Users\Jon\Test>java Test 100 100
Time taken: 1ms

Okay, so we hadn’t expected it to be slow… clearly we can ramp things up a bit. How about a source of one million items and 300,000 items to remove?

c:\Users\Jon\Test>java Test 1000000 300000
Time taken: 38ms

Hmm. That still seems pretty speedy. Now I feel I’ve been a little bit cruel, asking it to do all that removing. Let’s make it a bit easier – 300,000 source items and 300,000 removals:

c:\Users\Jon\Test>java Test 300000 300000
Time taken: 178131ms

Excuse me? Nearly three minutes? Yikes! Surely it ought to be easier to remove items from a smaller collection than the one we managed in 38ms?

Can someone explain why this is happening? Why is the HashSet<T>.removeAll method so slow?


The behaviour is (somewhat) documented in the javadoc:

This implementation determines which is the smaller of this set and the specified collection, by invoking the size method on each. If this set has fewer elements, then the implementation iterates over this set, checking each element returned by the iterator in turn to see if it is contained in the specified collection. If it is so contained, it is removed from this set with the iterator's remove method. If the specified collection has fewer elements, then the implementation iterates over the specified collection, removing from this set each element returned by the iterator, using this set's remove method.

What this means in practice, when you call source.removeAll(removals);:

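  • if the removals collection is smaller than source, removeAll iterates over removals and calls the remove method of HashSet for each element, which is fast (a hash lookup per call).

  • if the removals collection is the same size as source or larger, removeAll iterates over source and calls removals.contains for each element, which is slow for an ArrayList: every call is a linear scan. In the 300,000 / 300,000 case from the question no element is ever found, so each scan runs to the end of the list, for roughly 300,000 × 300,000 = 9 × 10^10 comparisons.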
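To see which branch is taken without waiting three minutes, one option is to count the calls to contains. The sketch below is only an illustration: CountingList and containsCallsDuringRemoveAll are hypothetical names introduced here, and it assumes the JDK in use still implements removeAll with the size comparison described in the javadoc above.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

public class RemoveAllBranchDemo {

    /** Hypothetical helper: an ArrayList that counts how often contains() is invoked. */
    static class CountingList extends ArrayList<Integer> {
        int containsCalls = 0;

        @Override
        public boolean contains(Object o) {
            containsCalls++;
            return super.contains(o); // still a linear scan, as in ArrayList
        }
    }

    /**
     * Builds the two collections the same way as the test program in the question
     * (non-negative source, negative removals), runs removeAll, and returns how
     * many times removals.contains() was called.
     */
    public static int containsCallsDuringRemoveAll(int sourceSize, int removalsSize) {
        Set<Integer> source = new HashSet<>();
        for (int i = 0; i < sourceSize; i++) {
            source.add(i);
        }
        CountingList removals = new CountingList();
        for (int i = 1; i <= removalsSize; i++) {
            removals.add(-i);
        }
        source.removeAll(removals);
        return removals.containsCalls;
    }

    public static void main(String[] args) {
        // Source larger than removals: the HashSet.remove branch is expected.
        System.out.println("1000 / 300: " + containsCallsDuringRemoveAll(1000, 300) + " contains() calls");
        // Equal sizes: the removals.contains branch is expected.
        System.out.println("300 / 300: " + containsCallsDuringRemoveAll(300, 300) + " contains() calls");
    }
}
```

Under that assumption, the first case should report no contains calls at all, and the second one call per source element, each of which scans the whole list.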

Quick fix:

Collection<Integer> removals = new HashSet<Integer>();
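If you can't change the type of removals (say it is handed to you as a List), two other workarounds come to mind. This is a minimal sketch, not a recommendation from the JDK docs; removeEach and removeAllViaHashSet are hypothetical helper names made up for this example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class RemoveAllWorkarounds {

    /**
     * Removes each element individually, so only Set.remove is used and
     * removals.contains is never called, whatever the relative sizes are.
     */
    public static boolean removeEach(Set<?> source, Collection<?> removals) {
        boolean modified = false;
        for (Object item : removals) {
            modified |= source.remove(item);
        }
        return modified;
    }

    /**
     * Copies removals into a HashSet first, so that whichever branch
     * removeAll picks, every lookup is a hash lookup. Costs one extra copy.
     */
    public static boolean removeAllViaHashSet(Set<?> source, Collection<?> removals) {
        return source.removeAll(new HashSet<Object>(removals));
    }

    public static void main(String[] args) {
        Set<Integer> source = new HashSet<>(Arrays.asList(0, 1, 2, 3));
        List<Integer> removals = new ArrayList<>(Arrays.asList(2, -1, -2, -3));

        System.out.println(removeEach(source, removals));          // was anything removed?
        System.out.println(removeAllViaHashSet(source, removals)); // second pass over the same data
        System.out.println(source);
    }
}
```

The explicit loop avoids the extra allocation; the copy keeps the call site a one-liner. Either way the result no longer depends on which collection happens to be larger.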

Note that there is an open bug that is very similar to what you describe. The bottom line seems to be that it is probably a poor choice but can't be changed because it is documented in the javadoc.


For reference, this is the code of removeAll (in Java 8 - haven't checked other versions):

public boolean removeAll(Collection<?> c) {
    Objects.requireNonNull(c);
    boolean modified = false;

    if (size() > c.size()) {
        for (Iterator<?> i = c.iterator(); i.hasNext(); )
            modified |= remove(i.next());
    } else {
        for (Iterator<?> i = iterator(); i.hasNext(); ) {
            if (c.contains(i.next())) {
                i.remove();
                modified = true;
            }
        }
    }
    return modified;
}