Numba JIT slower than pure python with parameterized function
I just wrote a simple benchmark comparing Numba and Julia, together with some discussion.
I'm wondering whether my Numba code could be fixed somehow, or if what I'm trying to do is indeed not supported by Numba.
The idea is to evaluate this function using a JIT-compiled quadrature rule.
g(p) = integrate exp(p*x) with respect to x
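Assuming the bounds [-1, 1] and the "- 10" offset used in the code below, the integral has a closed form. It is handy for checking the quadrature results:

```latex
g(p) = \int_{-1}^{1} \left(e^{p x} - 10\right)\,dx
     = \frac{e^{p} - e^{-p}}{p} - 20
     = \frac{2\sinh p}{p} - 20 \quad (p \neq 0),
\qquad g(0) = 2 - 20 = -18.
```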
This is the simple quadrature function:
import math
import numba as nb

@nb.njit
def quad_trap(f,a,b,N):
    h = (b-a)/N
    integral = h * ( f(a) + f(b) ) / 2
    for k in range(1, N):  # interior nodes only: f(a) and f(b) are already counted above
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk)
    return integral
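To check the numbers independently of Numba, here is a plain-Python sketch of the composite trapezoidal rule. The names quad_trap_py and exact are hypothetical helpers introduced only for this check. The interior loop starts at k = 1 because x_0 = a is already counted (with weight 1/2) in the endpoint term.

```python
import math


def quad_trap_py(f, a, b, N):
    """Composite trapezoidal rule in plain Python (no Numba), for reference."""
    h = (b - a) / N
    # Endpoints a and b get weight 1/2 ...
    integral = h * (f(a) + f(b)) / 2
    # ... interior nodes x_k = a + k*h, k = 1..N-1, get weight 1.
    for k in range(1, N):
        integral += h * f(a + k * h)
    return integral


def exact(p):
    # Closed form of the integral of exp(p*x) - 10 over [-1, 1], valid for p != 0.
    return 2 * math.sinh(p) / p - 20


if __name__ == "__main__":
    approx = quad_trap_py(lambda x: math.exp(x) - 10, -1.0, 1.0, 10_000)
    print(approx, exact(1.0))
```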
I can pass a JIT-compiled function to this function, like this one:
@nb.njit(nb.float64(nb.float64))
def func(x):
    return math.exp(x) - 10
And that's about 10-20X faster than pure Python, which is pretty good.
Now, what I would like to do is to pass a function of x and parametrized by p, something along the lines of:
def g(p):
    @nb.njit(nb.float64(nb.float64))
    def integrand(x):
        return math.exp(p*x) - 10
    return quad_trap(integrand, -1, 1, 10000)
Doing this seems to break Numba: it becomes incredibly slow, even compared with pure Python.
Am I doing something wrong, or is this feature indeed unsupported by Numba? (I did check the documentation, but I don't understand exactly where the problem is.) Thanks!
TL;DR: this feature does not appear to be supported by Numba yet.
"that's about 10-20X faster than pure Python, which is pretty good."
The Numba function quad_trap is compiled the first time you call it. If the types of the parameters change, Numba compiles the function again. Compilation time is generally far from negligible (from a few milliseconds to a few seconds). The usual way to avoid this is to specify the parameter types explicitly, so that the function is compiled when it is defined. However, AFAIK, this is not possible here (at least it is not documented) because one of the parameters is a function. That being said, since you presumably benchmark quad_trap with the same function every time, Numba should not recompile it, because the types of the provided arguments do not change.
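One way to see the compilation cost in a benchmark is to time the first call separately from the following ones. The helper first_and_steady below is a hypothetical name, a minimal sketch assuming that the first call of a Numba dispatcher is the one that triggers compilation (true for quad_trap, which has no explicit signature).

```python
import time


def first_and_steady(fn, *args, repeat=5):
    """Return (duration of the first call, best of `repeat` later calls), in seconds.

    For a lazily compiled Numba function the first duration includes
    JIT compilation, while the second one measures only execution.
    """
    t0 = time.perf_counter()
    fn(*args)
    first = time.perf_counter() - t0

    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return first, best
```

For example, first_and_steady(quad_trap, func, -1.0, 1.0, 10000) would be expected to report a much larger first duration than steady-state duration, since only the first call compiles quad_trap.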
"doing seems to break down Numba, which becomes incredibly slow even when compared with pure Python."
In recent versions of Numba, this code runs without any warning, but only because the function integrand has to be recompiled over and over: every call to g creates a brand-new integrand function, and Numba does not know whether its code (or the code of the functions and operators it calls) has changed, so it compiles it again, and quad_trap is then compiled again for this new function type. In older versions, Numba may complain that integrand reads the parameter p, which belongs to its enclosing (parent) function. This is called a closure.
Closures are generally less well supported by compilers because they are much harder to deal with: they need to read variables from the stack of their parent function. A recurring problem is that a closure can escape the scope of its parent function and be called outside of it, resulting in undefined behaviour (the closure would try to read the defunct stack of a function that has already returned).
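For contrast, plain CPython avoids the dangling-stack problem by storing captured variables in heap-allocated cell objects, so an escaping closure stays valid; a compiler such as Numba has to provide an equivalent guarantee, which is part of why it restricts closures. The sketch below (make_integrand is a hypothetical name) also shows that every call of the parent function creates a new function object, which is why decorating integrand inside g yields a fresh function, and a fresh compilation, on each call.

```python
import math


def make_integrand(p):
    # integrand is a closure: it reads p from the enclosing call.
    def integrand(x):
        return math.exp(p * x) - 10

    return integrand  # the closure escapes the scope of make_integrand


if __name__ == "__main__":
    f1 = make_integrand(2.0)
    f2 = make_integrand(2.0)
    print(f1 is f2)                          # a new function object per call
    print(f1.__closure__[0].cell_contents)   # p lives in a cell, not on a stack
    print(f1(0.0))                           # still callable after the parent returned
```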
One trick is to move the @nb.njit decorator from integrand to g, but Numba then refuses to compile g because it does not support closures that could escape the scope of their parent function (due to the issue described above). In your case, the closure does not escape the function where it is defined, but Numba cannot prove that (since quad_trap is already compiled), and unfortunately it also fails to do so when quad_trap is inlined (even though it could theoretically prove that this is safe). In fact, the documentation states:
"Numba now supports inner functions as long as they are non-recursive and only called locally, but not passed as argument or returned as result. The use of closure variables (variables defined in outer scopes) within an inner function is also supported."
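I think the @generated_jit decorator might help to solve this kind of problem, but I did not succeed in making it work in your specific case. It should at least help to compile g at definition time (like integrand) rather than during the first call. (Note that @generated_jit has since been deprecated in newer Numba releases in favour of numba.extending.overload.)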
One solution is to simply not use closures:
import math
import numba as nb

@nb.njit
def quad_trap_p(f,a,b,N,p):
    h = (b-a)/N
    integral = h * ( f(a,p) + f(b,p) ) / 2
    for k in range(1, N):  # interior nodes only: the endpoints are counted above
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk,p)
    return integral

@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
    return math.exp(p*x) - 10

def g(p):
    return quad_trap_p(integrand, -1, 1, 10000, p)