Speed up computation for Distance Transform on Image in Python


The implementation in the OP is a brute-force approach to the distance transform. This algorithm is O(n²) in the number of pixels n, as it computes the distance from each background pixel to each foreground pixel. Furthermore, because of the way it is vectorized, it requires a lot of memory: on my computer it couldn't compute the distance transform of a 256x256 image without thrashing. Many other algorithms are described in the literature; below I'll discuss two O(n) algorithms.

Note: Typically, the distance transform is computed for object pixels (value 1), giving their distance to the nearest background pixel (value 0). The code in the OP does the reverse, so the code I've pasted below follows OP's convention rather than the more common one.
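For reference, here is a brute-force sketch of the same computation, following OP's convention. It is our own illustration (the function name is hypothetical), not OP's code. It handles one background pixel at a time, so memory stays proportional to the number of foreground pixels, but the total work is still proportional to (background pixels) × (foreground pixels), i.e. quadratic in the image size.

```python
import numpy as np

def brute_force_distance(img):
    """Exact Euclidean distance from every zero pixel to the nearest non-zero
    pixel (OP's convention); non-zero pixels get 0."""
    fg = np.argwhere(img != 0)           # (k, 2) coordinates of foreground pixels
    dt = np.zeros(img.shape, dtype=float)
    if len(fg) == 0:
        dt[:] = np.inf                   # no foreground pixel: every distance is infinite
        return dt
    for r, c in np.argwhere(img == 0):
        # One background pixel at a time: O(k) memory instead of O(n * k),
        # but still O(k) work per background pixel.
        d2 = (fg[:, 0] - r) ** 2 + (fg[:, 1] - c) ** 2
        dt[r, c] = np.sqrt(d2.min())
    return dt
```

Calling `brute_force_distance(img == 0)` instead gives the more common convention: each object pixel then gets its distance to the nearest background pixel. A slow but exact function like this is also handy for checking the faster algorithms on small images.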


The easiest to implement, IMO, is the chamfer distance algorithm. This is a recursive algorithm that makes two passes over the image: one left to right and top to bottom, and one right to left and bottom to top. In each pass, the distances already computed for previously visited neighbors are propagated to the current pixel. This algorithm can be implemented using integer or floating-point distances between neighbors; the latter yields smaller errors, of course. But in both cases the errors can be reduced significantly by increasing the number of neighbors considered in this propagation. The algorithm is older, but G. Borgefors analyzed it and proposed suitable neighbor distances (G. Borgefors, "Distance Transformations in Digital Images", Computer Vision, Graphics, and Image Processing 34:344-371, 1986).
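To see where the approximation error comes from, note that on an obstacle-free grid the chamfer distance of an offset (dx, dy) is the cheapest 8-connected path: min(|dx|, |dy|) diagonal steps plus the remaining straight steps (this holds when a ≤ b ≤ 2a for edge weight a and diagonal weight b). The sketch below, with hypothetical helper names of our own, computes that path length and the worst relative error against the Euclidean length.

```python
import math

def chamfer_offset(dx, dy, a=3, b=4):
    """Shortest 8-connected path length from (0, 0) to (dx, dy) on an
    obstacle-free grid, with edge steps weighted a and diagonal steps weighted b.
    Assumes a <= b <= 2 * a."""
    dx, dy = abs(dx), abs(dy)
    diagonal = min(dx, dy)             # diagonal steps cover the shorter axis
    straight = max(dx, dy) - diagonal  # the rest are edge-connected steps
    return a * straight + b * diagonal

def worst_relative_error(a, b, scale, radius=50):
    """Largest relative error of chamfer_offset(...) / scale against the
    Euclidean length, over all offsets with 0 <= dx, dy <= radius."""
    worst = 0.0
    for dx in range(radius + 1):
        for dy in range(radius + 1):
            if dx == 0 and dy == 0:
                continue
            approx = chamfer_offset(dx, dy, a, b) / scale  # back to pixel units
            exact = math.hypot(dx, dy)
            worst = max(worst, abs(approx - exact) / exact)
    return worst
```

For example, `chamfer_offset(5, 2)` is 3·3 + 4·2 = 17, i.e. 17/3 ≈ 5.67 pixels against √29 ≈ 5.39. `worst_relative_error(3, 4, scale=3)` is about 0.057, reached along the diagonals where 4/3 stands in for √2. Trying real-valued weights, for example a = 0.955 and b = 1.3693 with `scale=1` (an illustrative choice), gives a smaller worst case of about 0.045.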

Here is an implementation using 3-4 distance (distance to edge-connected neighbors is 3, distance to vertex-connected neighbors is 4):

import numpy as np

def chamfer_distance(img):
    # 3-4 chamfer distance, following OP's convention: every zero pixel gets its
    # distance to the nearest non-zero pixel; non-zero pixels stay 0.
    # Divide the result by 3 for an approximate distance in pixels.
    # w indexes axis 0 and h indexes axis 1; both must be at least 2.
    w, h = img.shape
    dt = np.zeros((w, h), np.uint32)
    # Forward pass
    x = 0
    y = 0
    if img[x, y] == 0:
        dt[x, y] = 65535  # some large value
    for x in range(1, w):
        if img[x, y] == 0:
            dt[x, y] = 3 + dt[x-1, y]
    for y in range(1, h):
        x = 0
        if img[x, y] == 0:
            dt[x, y] = min(3 + dt[x, y-1], 4 + dt[x+1, y-1])
        for x in range(1, w-1):
            if img[x, y] == 0:
                dt[x, y] = min(4 + dt[x-1, y-1], 3 + dt[x, y-1], 4 + dt[x+1, y-1], 3 + dt[x-1, y])
        x = w-1
        if img[x, y] == 0:
            dt[x, y] = min(4 + dt[x-1, y-1], 3 + dt[x, y-1], 3 + dt[x-1, y])
    # Backward pass
    for x in range(w-2, -1, -1):
        y = h-1
        if img[x, y] == 0:
            dt[x, y] = min(dt[x, y], 3 + dt[x+1, y])
    for y in range(h-2, -1, -1):
        x = w-1
        if img[x, y] == 0:
            dt[x, y] = min(dt[x, y], 3 + dt[x, y+1], 4 + dt[x-1, y+1])
        # x must decrease here, so that dt[x+1, y] already holds the value
        # computed in this backward pass
        for x in range(w-2, 0, -1):
            if img[x, y] == 0:
                dt[x, y] = min(dt[x, y], 4 + dt[x+1, y+1], 3 + dt[x, y+1], 4 + dt[x-1, y+1], 3 + dt[x+1, y])
        x = 0
        if img[x, y] == 0:
            dt[x, y] = min(dt[x, y], 4 + dt[x+1, y+1], 3 + dt[x, y+1], 3 + dt[x+1, y])
    return dt

Note that much of the complexity here comes from avoiding out-of-bounds indexing while still computing distances all the way to the edges of the image. If we simply skip the pixels around the border of the image, the code becomes much simpler.
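A minimal sketch of that simplification, assuming it is acceptable to pad the distance array with a one-pixel frame of a very large value and skip that frame instead: every real pixel then has all eight neighbors, so the special cases for the first/last row and column disappear. The function name and the `a`/`b` weight parameters are our own; the convention is still OP's.

```python
import numpy as np

def chamfer_distance_padded(img, a=3, b=4):
    """Chamfer distance (edge weight a, diagonal weight b) from each zero pixel
    to the nearest non-zero pixel; non-zero pixels get 0."""
    big = np.iinfo(np.int64).max // 4  # "infinity" that cannot overflow when a or b is added
    rows, cols = img.shape
    # Frame of 'big' around the image, so every real pixel has 8 neighbors.
    dt = np.full((rows + 2, cols + 2), big, dtype=np.int64)
    dt[1:-1, 1:-1] = np.where(img != 0, 0, big)
    # Forward pass: top to bottom, left to right, using already-visited neighbors.
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if dt[i, j] != 0:
                dt[i, j] = min(dt[i, j],
                               b + dt[i - 1, j - 1], a + dt[i - 1, j],
                               b + dt[i - 1, j + 1], a + dt[i, j - 1])
    # Backward pass: bottom to top, right to left.
    for i in range(rows, 0, -1):
        for j in range(cols, 0, -1):
            if dt[i, j] != 0:
                dt[i, j] = min(dt[i, j],
                               b + dt[i + 1, j + 1], a + dt[i + 1, j],
                               b + dt[i + 1, j - 1], a + dt[i, j + 1])
    return dt[1:-1, 1:-1]  # drop the padding frame
```

With a single foreground pixel, the result should match the closed form a·(max − min) + b·min of the row and column offsets, which makes for an easy sanity check.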

Because it is a recursive algorithm (each pixel depends on distances computed earlier in the same pass), it cannot be vectorized, so the Python code will not be very efficient. Programmed in C or a similar language, however, it yields a very fast algorithm that gives a fairly good approximation to the Euclidean distance.

OpenCV's cv.distanceTransform implements this algorithm when called with a 3x3 or 5x5 mask (maskSize set to cv.DIST_MASK_3 or cv.DIST_MASK_5).


Another very efficient algorithm computes the square of the distance transform. The squared Euclidean distance is separable: the 2D transform can be obtained as a sequence of 1D transforms, one along each axis. This leads to an algorithm that is easy to parallelize. For each image row, the algorithm does a forward and a backward pass. Then, for each column of that intermediate result, it does another forward and backward pass. This process yields an exact Euclidean distance transform.
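A pure-Python sketch of this separable approach, assuming the 1D step is done as a lower envelope of parabolas in the style of Felzenszwalb and Huttenlocher (the first loop in `squared_edt_1d` builds the envelope, the second evaluates it). Function names and the `INF` constant are our own choices. Foreground pixels start at 0 and background pixels at a large finite value (not `np.inf`, which would produce `inf - inf = nan` in the intersection formula); each row is transformed, then each column of the result.

```python
import numpy as np

INF = 1e20  # finite stand-in for "infinitely far"

def squared_edt_1d(f):
    """Exact 1D squared-distance transform: d[q] = min over p of (q - p)**2 + f[p]."""
    n = len(f)
    d = np.empty(n)
    v = np.zeros(n, dtype=np.int64)  # locations of the parabolas in the envelope
    z = np.empty(n + 1)              # boundaries between consecutive parabolas
    k = 0
    z[0], z[1] = -np.inf, np.inf
    # First scan: build the lower envelope from left to right.
    for q in range(1, n):
        while True:
            p = v[k]
            # horizontal position where the parabolas rooted at p and q intersect
            s = ((f[q] + q * q) - (f[p] + p * p)) / (2 * q - 2 * p)
            if s > z[k]:
                break
            k -= 1  # the parabola at p is hidden by the new one; drop it
        k += 1
        v[k] = q
        z[k] = s
        z[k + 1] = np.inf
    # Second scan: evaluate the envelope at every sample.
    k = 0
    for q in range(n):
        while z[k + 1] < q:
            k += 1
        d[q] = (q - v[k]) ** 2 + f[v[k]]
    return d

def euclidean_distance_transform(img):
    """Exact Euclidean distance from each zero pixel to the nearest non-zero
    pixel (OP's convention); non-zero pixels get 0."""
    f = np.where(img != 0, 0.0, INF)
    for i in range(f.shape[0]):  # pass 1: every row independently
        f[i, :] = squared_edt_1d(f[i, :])
    for j in range(f.shape[1]):  # pass 2: every column of the result
        f[:, j] = squared_edt_1d(f[:, j])
    return np.sqrt(f)
```

Each call to `squared_edt_1d` only touches one row or one column, which is where the parallelism comes from; in pure Python the loops are still slow, so this is mainly useful for understanding or testing. An image with no foreground pixels comes back filled with values around 1e10 rather than infinity.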

This algorithm was first proposed by R. van den Boomgaard in his Ph.D. thesis in 1992. Unfortunately, this went unnoticed. The algorithm was later proposed again by A. Meijster, J.B.T.M. Roerdink and W.H. Hesselink (A General Algorithm for Computing Distance Transforms in Linear Time, Mathematical Morphology and its Applications to Image and Signal Processing, pp. 331-340, 2002), and once more by P. Felzenszwalb and D. Huttenlocher (Distance transforms of sampled functions, Technical report, Cornell University, 2004).

This is the most efficient algorithm known, in part because it is the only one that can be easily and efficiently parallelized (computation on each image row, and later on each image column, is independent of other rows/columns).

Unfortunately I don't have any Python code for this one to share, but you can find implementations online. For example, OpenCV's cv.distanceTransform implements this algorithm (with distanceType cv.DIST_L2 and maskSize cv.DIST_MASK_PRECISE), and DIPlib's dip.EuclideanDistanceTransform does too.