How do I optimise numpy.packbits with numba?

I'm trying to optimise numpy.packbits:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def _numba_pack(arr, div, su):
    for i in prange(div):
        s = 0
        for j in range(i*8, i*8+8):
            s = 2*s + arr[j]
        su[i] = s
        
def numba_packbits(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

>>> X = np.random.randint(2, size=99, dtype=bool)
>>> print(numba_packbits(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]
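For checking results, the behaviour that numba_packbits reproduces (big-endian bit order, with the final partial byte zero-padded on the right) can be written in pure NumPy. This is a minimal sketch of the semantics (`reference_packbits` is a hypothetical helper name), not NumPy's internal C implementation:

```python
import numpy as np

def reference_packbits(bits):
    """Pure-NumPy sketch of np.packbits(bits) for 1-D input (bitorder='big')."""
    # Like np.packbits, treat any nonzero value as a 1 bit.
    bits = (np.asarray(bits).ravel() != 0).astype(np.uint8)
    # Zero-pad to a multiple of 8 so the last byte's bits stay left-aligned.
    pad = (-bits.size) % 8
    padded = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)])
    # Each row of 8 bits becomes one byte: the first bit has weight 128.
    weights = np.array([128, 64, 32, 16, 8, 4, 2, 1], dtype=np.uint16)
    return (padded.reshape(-1, 8) @ weights).astype(np.uint8)

print(reference_packbits([1, 0, 0, 0, 0, 0, 0, 1, 1]))  # [129 128]
```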

It appears to be 2 to 2.5 times slower than np.packbits(X). How is this implemented in numpy internally? Could this be improved in numba?

I'm using numpy 1.21.2 and numba 0.53.1, installed with conda. My platform is:

[Screenshot: platform details]

Results:

import benchit
from numpy import packbits
%matplotlib inline
benchit.setparams(rep=5)

sizes = [100000, 300000, 1000000, 3000000, 10000000, 30000000]
N = sizes[-1]
arr = np.random.randint(2, size=N, dtype=bool)
fns = [numba_packbits, packbits]

in_ = {s/1000000: (arr[:s], ) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Millions of bits')
t.plot(logx=True, figsize=(12, 6), fontsize=14)
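If benchit is not available, a rough comparison can be made with the standard library's timeit module. This is a hedged alternative harness (`best_time` is a hypothetical helper name), not the setup that produced the plot:

```python
import timeit
import numpy as np

def best_time(fn, arg, repeat=5, number=10):
    """Best average time per call, in seconds, over `repeat` runs of `number` calls."""
    # Warm-up call so that lazy JIT compilation (e.g. Numba) is not timed.
    fn(arg)
    runs = timeit.repeat(lambda: fn(arg), repeat=repeat, number=number)
    return min(runs) / number

arr = np.random.randint(2, size=1_000_000, dtype=bool)
print(f"np.packbits: {best_time(np.packbits, arr) * 1e3:.3f} ms")
```

The Numba version can be timed the same way, e.g. best_time(numba_packbits, arr).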

[Plot: benchit timings of numba_packbits and packbits against Millions of bits, log-scaled x-axis]

Update

Using Jérôme's answer:

@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64_byJérôme(arr, su, pos):
    # Packs exactly 512 bits of arr into the 64 bytes of su (pos is unused).
    for i in range(64):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]

@njit(parallel=True)
def _numba_pack_byJérôme(arr, div, su):
    # Chunk i covers output bytes [64*i, 64*(i+1)) and input bits [512*i, 512*(i+1)).
    for i in prange(div//64):
        _numba_pack_x64_byJérôme(arr[i*512:(i+1)*512], su[i*64:(i+1)*64], i)
    # Whole bytes left over after the last complete 64-byte chunk.
    for i in range(div//64*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
        
def numba_packbits_byJérôme(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack_byJérôme(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

Usage:

>>> print(numba_packbits_byJérôme(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]

Results:

[Plot: updated benchmark results]


Solution 1:

There are several issues with the Numba implementation. One of them is that parallel loops break the constant propagation optimization in LLVM (the JIT compiler Numba uses through llvmlite). As a result, critical information like array strides is not propagated, which yields a slow scalar implementation instead of a SIMD one, plus additional unneeded instructions to compute the offsets. The same issue can also be seen in C code. Numpy added specific macros to help compilers automatically vectorize the code (i.e. use SIMD instructions) when the stride of the working dimension is actually 1.
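To make the stride argument concrete: a NumPy bool array has an item size of 1 byte, so a contiguous 1-D bool array has a stride of 1, while a strided view does not. In a Numba signature, `bool_[::1]` denotes that C-contiguous (unit-stride) layout, which is the property the compiler needs to know statically. A small illustration with plain NumPy:

```python
import numpy as np

a = np.zeros(1024, dtype=np.bool_)
print(a.strides, a.flags['C_CONTIGUOUS'])   # (1,) True

# A step-2 view skips every other element: stride 2, not contiguous.
b = a[::2]
print(b.strides, b.flags['C_CONTIGUOUS'])   # (2,) False

# A plain start:stop slice of a contiguous 1-D array keeps a stride of 1.
c = a[8:520]
print(c.strides, c.flags['C_CONTIGUOUS'])   # (1,) True
```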

A solution to overcome the constant propagation issue is to call another Numba function. This function must not be inlined. Its signature should be provided manually so that the compiler knows at compile time that the array stride is 1 and can generate faster code. Finally, the function should work on fixed-size chunks, because function calls are expensive (the cost is amortized over a whole chunk) and because the compiler can then vectorize the code. Unrolling the loop with shifts also produces faster code (although it is uglier). Here is an example:
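The index arithmetic behind this chunking can be sketched in plain Python. With 64 output bytes per chunk, chunk i reads input bits [512*i, 512*(i+1)) and writes bytes [64*i, 64*(i+1)); the leftover whole bytes and the final partial byte are handled separately. `chunk_plan` is a hypothetical helper for illustration only:

```python
def chunk_plan(n_bits, chunk_bytes=64):
    """Split the packing of n_bits into full chunks, leftover whole bytes and tail bits.

    Returns (chunks, leftover_bytes, tail_bits), where each chunk is
    ((bit_start, bit_stop), (byte_start, byte_stop)).
    """
    div, mod = divmod(n_bits, 8)          # whole output bytes, bits in the partial byte
    n_chunks = div // chunk_bytes
    chunk_bits = chunk_bytes * 8
    chunks = [((i * chunk_bits, (i + 1) * chunk_bits),
               (i * chunk_bytes, (i + 1) * chunk_bytes)) for i in range(n_chunks)]
    leftover_bytes = (n_chunks * chunk_bytes, div)   # done by the scalar loop
    tail_bits = (div * 8, n_bits)                    # mod bits, zero-padded into one byte
    return chunks, leftover_bytes, tail_bits

print(chunk_plan(1100))
# ([((0, 512), (0, 64)), ((512, 1024), (64, 128))], (128, 137), (1096, 1100))
```

The chunks never overlap, so parallel iterations write disjoint parts of the output.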

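@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64(arr, su, pos):
    # Packs exactly 512 bits of arr into the 64 bytes of su (pos is unused).
    # The unrolled shifts give big-endian bit order, like np.packbits.
    for i in range(64):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]

@njit('void(bool_[::1], int_, uint8[::1])', parallel=True)
def _numba_pack(arr, div, su):
    # Chunk i covers output bytes [64*i, 64*(i+1)) and input bits [512*i, 512*(i+1)).
    for i in prange(div//64):
        _numba_pack_x64(arr[i*512:(i+1)*512], su[i*64:(i+1)*64], i)
    # Whole bytes left over after the last complete 64-byte chunk.
    for i in range(div//64*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]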
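With the 99-bit X from the question, div//64 is 0, so only the scalar remainder loop runs and the chunked path is never exercised; inputs of 1024 bits or more are needed to cover several chunks. A minimal sketch for validation only: the same decomposition transcribed into plain Python/NumPy (pack_byte, pack_x64 and packbits_chunked are hypothetical names; there is no Numba here, so it is slow), checked against np.packbits at sizes around the chunk boundaries:

```python
import numpy as np

def pack_byte(a, j):
    # Unrolled big-endian packing of the 8 bits a[j], ..., a[j+7].
    return ((a[j] << 7) | (a[j+1] << 6) | (a[j+2] << 5) | (a[j+3] << 4)
            | (a[j+4] << 3) | (a[j+5] << 2) | (a[j+6] << 1) | a[j+7])

def pack_x64(arr, su):
    # arr holds exactly 512 bits and su exactly 64 bytes.
    for i in range(64):
        su[i] = pack_byte(arr, i * 8)

def packbits_chunked(bits):
    a = (np.asarray(bits) != 0).astype(np.int64)
    div, mod = divmod(a.size, 8)
    su = np.zeros(div + (mod > 0), dtype=np.uint8)
    for i in range(div // 64):                 # prange in the Numba version
        pack_x64(a[i*512:(i+1)*512], su[i*64:(i+1)*64])
    for i in range(div // 64 * 64, div):       # leftover whole bytes
        su[i] = pack_byte(a, i * 8)
    if mod:                                    # final partial byte, zero-padded
        su[-1] = sum(int(b) << (7 - k) for k, b in enumerate(a[div*8:]))
    return su

x = np.random.randint(2, size=1100, dtype=bool)
print(np.array_equal(packbits_chunked(x), np.packbits(x)))  # True
```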

Benchmark

Here are performance results on my 6-core machine (i5-9600KF) with a billion random items as input:

Initial Numba (seq):    189 ms  (x0.7)
Initial Numba (par):    141 ms  (x1.0)
Numpy (seq):             98 ms  (x1.4)
Optimized Numba (par):   35 ms  (x4.0)
Theoretical optimal:     27 ms  (x5.2)  [fully memory-bound case]

This new implementation is 4 times faster than the initial parallel implementation and about 3 times faster than Numpy.
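The x-factors in the table are the initial parallel time divided by each time. The "theoretical optimal" row also implies a memory bandwidth if one assumes (my assumption, not stated above) that a fully memory-bound run reads one byte per input bool and writes one byte per 8 bools:

```python
times_ms = {
    "Initial Numba (seq)": 189,
    "Initial Numba (par)": 141,
    "Numpy (seq)": 98,
    "Optimized Numba (par)": 35,
    "Theoretical optimal": 27,
}
baseline = times_ms["Initial Numba (par)"]
speedups = {name: round(baseline / t, 1) for name, t in times_ms.items()}
print(speedups)  # 0.7, 1.0, 1.4, 4.0, 5.2, as in the table

n_items = 1_000_000_000
# Assumed traffic: 1 byte read per bool plus 1 byte written per 8 bools.
bytes_moved = n_items + n_items // 8
bandwidth_gb_s = bytes_moved / (times_ms["Theoretical optimal"] * 1e-3) / 1e9
print(f"{bandwidth_gb_s:.1f} GB/s")  # 41.7 GB/s
```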


Delving into the generated assembly code

When the initial implementation is compiled with parallel=False and prange replaced by range, the following assembly code is generated on my Intel processor supporting AVX2:

.LBB0_7:
    vmovdqu 112(%rdx,%rax,8), %xmm1
    vmovdqa 384(%rsp), %xmm3
    vpshufb %xmm3, %xmm1, %xmm0
    vmovdqu 96(%rdx,%rax,8), %xmm2
    vpshufb %xmm3, %xmm2, %xmm3
    vpunpcklwd  %xmm0, %xmm3, %xmm3
    vmovdqu 80(%rdx,%rax,8), %xmm15
    vmovdqa 368(%rsp), %xmm5
    vpshufb %xmm5, %xmm15, %xmm4
    vmovdqu 64(%rdx,%rax,8), %xmm0
    [...] <------------------------------  ~180 other instructions discarded
    vpcmpeqb    %xmm3, %xmm11, %xmm2
    vpandn  %xmm8, %xmm2, %xmm2
    vpor    %xmm2, %xmm1, %xmm1
    vpcmpeqb    %xmm3, %xmm0, %xmm0
    vpaddb  %xmm0, %xmm1, %xmm0
    vpsubb  %xmm4, %xmm0, %xmm0
    vmovdqu %xmm0, (%r11,%rax)
    addq    $16, %rax
    cmpq    %rax, %rsi
    jne .LBB0_7

The code is not very good: it uses many unneeded instructions (like SIMD comparison instructions, probably due to implicit casts from boolean types), a lot of registers are temporarily stored on the stack (register spilling), and it uses 128-bit AVX vectors instead of the 256-bit AVX ones supported on my machine. That being said, the code is vectorized and each loop iteration writes 16 bytes at once without any conditional branch (except the loop branch), so the resulting performance is not so bad.

In fact, the Numpy code is much smaller and more efficient. This is why it is about 2 times faster than the sequential Numba code on my machine with big inputs. Here is the hot assembly loop:

4e8:
    mov      (%rdx,%rax,8),%rcx
    bswap    %rcx
    mov      %rcx,0x20(%rsp)
    mov      0x8(%rdx,%rax,8),%rcx
    add      $0x2,%rax
    movq     0x20(%rsp),%xmm0
    bswap    %rcx
    mov      %rcx,0x20(%rsp)
    movhps   0x20(%rsp),%xmm0
    pcmpeqb  %xmm1,%xmm0
    pcmpeqb  %xmm1,%xmm0
    pmovmskb %xmm0,%ecx
    mov      %cl,(%rsi)
    movzbl   %ch,%ecx
    mov      %cl,(%rsi,%r13,1)
    add      %r9,%rsi
    cmp      %rax,%r8
    jg       4e8

It reads values in 8-byte chunks and processes them partly with 128-bit SSE instructions. Two bytes are written per iteration. That being said, it is not optimal either, because 256-bit SIMD instructions are not used, and I think the code can be optimized further.
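The trick in this loop can be emulated in NumPy to see why it yields packbits' big-endian bit order. This is an illustrative sketch (the function name is hypothetical, and xmm1 is assumed to hold zeros): reversing each 8-byte half mimics bswap, the two pcmpeqb against zero map every nonzero byte to 0xFF, and pmovmskb moves bit 7 of byte k into bit k of a 16-bit mask whose low byte (cl) and high byte (ch) are the two output bytes:

```python
import numpy as np

def sse_pack16_emulated(bits16):
    """Emulate the SSE packing of 16 bool bytes into 2 output bytes."""
    b = np.asarray(bits16, dtype=np.uint8)
    assert b.shape == (16,)
    # bswap on each 64-bit load: reverses the 8 bytes of each half.
    v = np.concatenate([b[7::-1], b[15:7:-1]])
    # pcmpeqb against zero twice: 0xFF where byte == 0, then 0xFF where byte != 0.
    v = np.where(v == 0, 0xFF, 0x00).astype(np.uint8)
    v = np.where(v == 0, 0xFF, 0x00).astype(np.uint8)
    # pmovmskb: bit k of the mask is the most significant bit of byte k.
    mask = 0
    for k in range(16):
        mask |= int(v[k] >> 7) << k
    # Low byte and high byte are stored as two consecutive output bytes.
    return np.array([mask & 0xFF, mask >> 8], dtype=np.uint8)

bits = np.random.randint(2, size=16, dtype=bool)
print(sse_pack16_emulated(bits), np.packbits(bits))  # the two arrays match
```

Because bswap puts the first element of each 8-byte group in the highest byte, it lands in bit 7 of its output byte, which is exactly packbits' default ordering.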

When the initial parallel code is used, here is the assembly code of the hot loop:

.LBB3_4:
     movq %r9, %rax
     leaq (%r10,%r14), %r9
     movq %r15, %rsi
     sarq $63, %rsi
     andq %rdx, %rsi
     addq %r11, %rsi
     cmpb $0, (%r14,%rsi)
     setne     %cl
     addb %cl, %cl
     [...] <---------------  56 instructions (with few 256-bit AVX ones)
     orb  %bl, %cl
     orb  %al, %cl
     orb  %dl, %cl
     movq %rbp, %rdx
     movb %cl, (%r8,%r15)
     incq %r15
     decq %rdi
     addq $8, %r14
     cmpq $1, %rdi
     jg   .LBB3_4

The above code is mostly not vectorized and is quite inefficient. It uses a lot of instructions (including fairly slow ones like setne/cmovlq/cmpb to perform many conditional stores) in each iteration just to write one byte at a time. Numpy executes about 8 times fewer instructions for the same number of written bytes. The inefficiency of this code is mitigated by the use of multiple threads; in the end, the parallel version can be a bit faster on machines with many cores (e.g. >= 6).

The improved implementation provided at the beginning of this answer generates code similar to the sequential version above, but runs on multiple threads (so it is still far from optimal, but much better).