Union of intervals

I've got a class representing an interval. This class has two properties "start" and "end" of a comparable type. Now I'm searching for an efficient algorithm to take the union of a set of such intervals.

Thanks in advance.


Sort them by one of the endpoints (start, for example), then check each interval for overlap with its right-hand neighbor as you move through the list.

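class tp:
    def __repr__(self):
        return "(%d,%d)" % (self.start, self.end)

    def __init__(self, start, end):
        self.start = start
        self.end = end


s = [tp(5, 10), tp(7, 8), tp(0, 5)]
s.sort(key=lambda iv: iv.start)  # sort by start point
y = [s[0]]
for x in s[1:]:
    if y[-1].end < x.start:
        # x starts after the current merged interval ends: begin a new one
        y.append(x)
    elif y[-1].end == x.start:
        # x starts exactly where the current one ends: extend it
        y[-1].end = x.end
    # Note: an x that starts strictly inside y[-1] is not handled, so an
    # overlap such as (0,1), (0,3) leaves (0,1) unextended.
print(y)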

Use the sweep line algorithm. Sort all the endpoints into one list, keeping with each item whether it is the beginning or the end of its interval. This step is O(n log n). Then make a single pass over the sorted items to compute the merged intervals, which is O(n).

O(n log n) + O(n) = O(n log n)
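A minimal Python sketch of this sweep, assuming intervals are given as plain `(start, end)` tuples rather than the OP's class (the function name `union_sweep` is hypothetical). Each endpoint is tagged 0 for a start and 1 for an end, so sorting the tuples puts a start before an end at the same value and touching intervals get merged; a count of currently open intervals then shows where each merged interval begins and ends.

```python
from datetime import date


def union_sweep(intervals):
    """Merge (start, end) pairs by sweeping over tagged endpoints."""
    # Tag each endpoint: 0 = start, 1 = end. Tuple ordering puts a start
    # before an end at the same value, so (0, 5) and (5, 10) are merged.
    events = []
    for start, end in intervals:
        events.append((start, 0))
        events.append((end, 1))
    events.sort()  # O(n log n)

    merged = []
    open_count = 0  # how many intervals cover the current sweep position
    current_start = None
    for value, kind in events:  # single O(n) pass
        if kind == 0:
            if open_count == 0:
                current_start = value  # a merged interval begins here
            open_count += 1
        else:
            open_count -= 1
            if open_count == 0:
                merged.append((current_start, value))  # and ends here
    return merged


print(union_sweep([(5, 10), (7, 8), (0, 5)]))  # [(0, 10)]
print(union_sweep([(0, 1), (0, 3), (4, 5)]))   # [(0, 3), (4, 5)]

# Any mutually comparable endpoint type works, e.g. datetime.date:
print(union_sweep([(date(2024, 1, 1), date(2024, 1, 10)),
                   (date(2024, 1, 5), date(2024, 1, 20))]))
```

Only comparisons between endpoints are needed, which matches the OP's "comparable type" requirement.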


It turns out this problem has been solved many times over, at varying levels of sophistication, under names such as interval tree (http://en.wikipedia.org/wiki/Interval_tree), segment tree (http://en.wikipedia.org/wiki/Segment_tree), and also 'RangeTree'.

(As the OP's question involves large numbers of intervals, these data structures matter.)


As for my own choice of Python library:

  • From testing, I find that the option that best combines being full-featured with being current in Python (not bit-rotted) is the 'Interval' and 'Union' classes from SymPy; see: http://sympystats.wordpress.com/2012/03/30/simplifying-sets/

  • Another good-looking choice, with higher performance but fewer features (e.g. it didn't work for floating-point range removal): https://pypi.python.org/pypi/Banyan
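For reference, a small example of the SymPy route, assuming a reasonably recent SymPy is installed; it uses only the public `Interval` and `Union` classes, with the same intervals as the first answer written as tuples. `Interval(a, b)` is closed by default (pass `left_open=True` or `right_open=True` otherwise), which is why adjacent intervals sharing an endpoint are merged.

```python
from sympy import Interval, Union

# Same data as the first answer, as (start, end) pairs.
pairs = [(5, 10), (7, 8), (0, 5)]

# Closed intervals [0, 5] and [5, 10] share the point 5, so everything
# collapses into a single interval: u == Interval(0, 10).
u = Union(*[Interval(a, b) for a, b in pairs])
print(u)

# Disjoint pieces stay separate inside a Union object.
v = Union(Interval(0, 1), Interval(2, 3))
print(v.args)  # the individual intervals
print(2.5 in v, 1.5 in v)  # membership tests
```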

Finally, search SO itself for IntervalTree, SegmentTree, or RangeTree, and you'll find plenty more answers and pointers.


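The algorithm by geocar fails when:

s = [tp(0, 1), tp(0, 3)]

Here (0,3) neither starts after (0,1) ends nor starts exactly at its end, so it is silently dropped and the result is [(0,1)] instead of [(0,3)].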

I'm not entirely sure, but I think this is the correct way:

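class tp:
    def __repr__(self):
        return '(%.2f,%.2f)' % (self.start, self.end)

    def __init__(self, start, end):
        self.start = start
        self.end = end


s = [tp(0, 1), tp(0, 3), tp(4, 5)]
s.sort(key=lambda iv: iv.start)  # sort by start point
print(s)
# Copy the first interval so that extending it does not modify s
y = [tp(s[0].start, s[0].end)]
for x in s[1:]:
    if y[-1].end < x.start:
        # Gap before x: it starts a new merged interval
        y.append(tp(x.start, x.end))
    elif x.end > y[-1].end:
        # x touches or overlaps the current interval and reaches further
        y[-1].end = x.end
print(y)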

I also implemented it for subtraction:

# Subtraction: remove interval z from every interval in s
z = tp(1.5, 5)  # interval to be subtracted
s = [tp(0, 1), tp(0, 3), tp(3, 4), tp(4, 6)]

s.sort(key=lambda iv: iv.start)
print(s)
result = []
for x in s:
    if x.end <= z.start or x.start >= z.end:
        # x does not overlap z: keep it unchanged
        result.append(x)
        continue
    if x.start < z.start:
        # the part of x to the left of z survives
        result.append(tp(x.start, z.start))
    if x.end > z.end:
        # the part of x to the right of z survives
        result.append(tp(z.end, x.end))
    # if neither condition holds, z covers x entirely and x is dropped
s = sorted(result, key=lambda iv: iv.start)
print(s)

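Sort all the endpoints. Then go through the list, incrementing a counter at each "start" point and decrementing it at each "end" point. When the counter rises from 0, you are at the start of an interval in the union; when it drops back to 0, you are at its end. At equal values, process start points before end points; otherwise touching intervals such as (0,1) and (1,2) are split, and a zero-length interval would drive the counter negative.

The counter will never go negative, and it will be back at 0 at the end of the list.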
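A sketch of the counter approach, assuming a hypothetical `Interval` namedtuple as a stand-in for the OP's class (anything with `.start` and `.end` attributes works). The sort key puts start points before end points at equal values, and the assertions check the two properties stated above.

```python
from collections import namedtuple

# Hypothetical stand-in for the OP's class: any object with .start and .end works.
Interval = namedtuple("Interval", ["start", "end"])


def union_by_counter(intervals):
    """Union of intervals by counting how many are open at each point."""
    # +1 marks a start point, -1 an end point. Sorting on (value, -delta)
    # puts starts before ends at the same value, so touching intervals
    # merge and the counter never dips below zero.
    points = [(iv.start, +1) for iv in intervals] + [(iv.end, -1) for iv in intervals]
    points.sort(key=lambda p: (p[0], -p[1]))

    result = []
    counter = 0
    start = None
    for value, delta in points:
        if delta == +1 and counter == 0:
            start = value  # counter leaves 0: a union interval begins
        counter += delta
        assert counter >= 0, "counter went negative"
        if counter == 0:
            result.append(Interval(start, value))  # counter back at 0: it ends
    assert counter == 0, "counter did not return to 0"
    return result


print(union_by_counter([Interval(5, 10), Interval(7, 8), Interval(0, 5)]))
# [Interval(start=0, end=10)]
```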