OpenCV - Splitting and merging alpha channels slow

I am using Python OpenCV to split channels and remove black background like this...

    b_channel, g_channel, r_channel = cv2.split(image_1)
    alpha_channel = np.zeros_like(gray)

    for p in range(alpha_channel.shape[0]):
        for q in range(alpha_channel.shape[1]):
            if b_channel[p][q]!=0 or g_channel[p][q]!=0 or r_channel[p][q]!=0:
                alpha_channel[p][q] = 255

    merged = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))

This works, but it takes around 10 seconds to complete on an image that is only 200 KB.

Is there a more efficient way to do this, or are there any speed gains I could make with the code I have?


Solution 1:

Iterating over pixels with a Python for loop is very slow and inefficient. Also, as the OpenCV documentation notes:

cv2.split() is a costly operation (in terms of time). So do it only if you need it. Otherwise go for Numpy indexing.

You can vectorise the operation with NumPy indexing instead:

    # create the image with an alpha channel
    img_rgba = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)

    # mask: True where any channel of the pixel is non-zero
    mask = (img[:, :, 0:3] != [0, 0, 0]).any(2)
    # assign the mask to the alpha (last) channel of the image
    img_rgba[:, :, 3] = (mask * 255).astype(np.uint8)
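As a minimal, self-contained sketch of the same idea without cv2.split() or cv2.cvtColor(), the alpha channel can be built and appended with NumPy alone. The function name add_alpha_from_nonblack and the tiny demo image are hypothetical and only serve as illustration; the sketch assumes an (H, W, 3) uint8 array such as the one cv2.imread() returns.

```python
import numpy as np


def add_alpha_from_nonblack(image_bgr):
    """Return a 4-channel copy whose alpha is 255 wherever a pixel is not pure black.

    Hypothetical helper: assumes image_bgr is an (H, W, 3) uint8 array.
    """
    # True for every pixel that has at least one non-zero channel
    mask = image_bgr.any(axis=2)
    # Turn the boolean mask into 0 / 255 values with the image's dtype
    alpha = mask.astype(image_bgr.dtype) * 255
    # Append alpha as the fourth channel; the three colour channels are unchanged
    return np.dstack((image_bgr, alpha))


# Demo on a 2x2 image: two black pixels and two non-black pixels
demo = np.zeros((2, 2, 3), dtype=np.uint8)
demo[0, 1] = (0, 0, 7)
demo[1, 0] = (10, 20, 30)

result = add_alpha_from_nonblack(demo)
# Alpha is 255 at positions (0, 1) and (1, 0), and 0 at the two black pixels
print(result[:, :, 3])
```

Because the whole image is processed in a few array operations, there is no Python-level work per pixel, which is where the nested loops in the question spend their time.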

Solution 2:

For what you're doing, using cv2.bitwise_or seems to be the fastest method:

    import time

    import cv2
    import numpy as np

    image_1 = img  # img: the 3-channel image loaded earlier

    # your method
    start_time = time.time()
    b_channel, g_channel, r_channel = cv2.split(image_1)
    # gray: single-channel image with the same height and width as image_1
    alpha_channel = np.zeros_like(gray)
    for p in range(alpha_channel.shape[0]):
        for q in range(alpha_channel.shape[1]):
            if b_channel[p][q] != 0 or g_channel[p][q] != 0 or r_channel[p][q] != 0:
                alpha_channel[p][q] = 255
    elapsed_time = time.time() - start_time
    print('for cycles:  ' + str(elapsed_time * 1000.0) + ' milliseconds')

    # my method
    start_time = time.time()
    b_channel, g_channel, r_channel = cv2.split(image_1)
    # OR the three channels together: the result is non-zero wherever any channel is non-zero
    alpha_channel2 = cv2.bitwise_or(g_channel, r_channel)
    alpha_channel2 = cv2.bitwise_or(alpha_channel2, b_channel)
    # set every non-zero value to 255
    _, alpha_channel2 = cv2.threshold(alpha_channel2, 0, 255, cv2.THRESH_BINARY)
    elapsed_time2 = time.time() - start_time
    print('bitwise + threshold:  ' + str(elapsed_time2 * 1000.0) + ' milliseconds')

    # anubhav's method
    start_time = time.time()
    img_rgba = cv2.cvtColor(image_1, cv2.COLOR_RGB2RGBA)
    # mask: True where any channel of the pixel is non-zero
    mask = (image_1[:, :, 0:3] != [0, 0, 0]).any(2)
    # assign the mask to the alpha (last) channel of the image
    img_rgba[:, :, 3] = (mask * 255).astype(np.uint8)
    elapsed_time3 = time.time() - start_time
    print('anubhav:  ' + str(elapsed_time3 * 1000.0) + ' milliseconds')

    for cycles: 2146.300792694092 milliseconds
    bitwise + threshold: 4.959583282470703 milliseconds
    anubhav: 27.924776077270508 milliseconds
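Single time.time() measurements like these can be noisy, and they do not show that the methods produce the same alpha channel. The sketch below is one way to check both, under some loud assumptions: it uses a synthetic random image rather than a real photo, the function names are hypothetical, and alpha_bitwise is a NumPy stand-in for the cv2.bitwise_or + cv2.threshold combination rather than the OpenCV calls themselves, so its timing is not comparable to the figures above.

```python
import timeit

import numpy as np


def alpha_loop(image):
    """Per-pixel Python loop, in the spirit of the code in the question."""
    alpha = np.zeros(image.shape[:2], dtype=np.uint8)
    for p in range(image.shape[0]):
        for q in range(image.shape[1]):
            if image[p, q].any():
                alpha[p, q] = 255
    return alpha


def alpha_any(image):
    """Vectorised comparison, as in Solution 1."""
    return (image != 0).any(axis=2).astype(np.uint8) * 255


def alpha_bitwise(image):
    """NumPy stand-in (an assumption) for bitwise_or + threshold in Solution 2."""
    combined = image[:, :, 0] | image[:, :, 1] | image[:, :, 2]
    return np.where(combined > 0, 255, 0).astype(np.uint8)


# Synthetic test image: random colours with roughly half the pixels forced to black
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(60, 80, 3), dtype=np.uint8)
image[rng.random((60, 80)) < 0.5] = 0

methods = (alpha_loop, alpha_any, alpha_bitwise)

# Correctness first: every method should yield the same alpha channel
results = [method(image) for method in methods]
all_equal = all(np.array_equal(results[0], other) for other in results[1:])
print("identical alpha channels:", all_equal)

# Then timing: best of several repeats is less sensitive to background noise
for method in methods:
    best = min(timeit.repeat(lambda: method(image), number=5, repeat=3)) / 5
    print(f"{method.__name__}: {best * 1000:.3f} ms per call")
```

Absolute numbers will depend on the machine and image size, so only the relative ordering on your own data is meaningful.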