Does JAX run slower than NumPy?

I've recently started learning JAX. I wrote a short snippet in NumPy and its equivalent in JAX. I expected JAX to be faster, but when I profile the two versions, the NumPy code is much faster than the JAX code. I was wondering whether this is generally true, or whether there is an issue in my implementation.

NumPy code:

import numpy as np
from time import time as tm


def gp(x):
    # element-wise ReLU
    return np.maximum(np.zeros(x.shape), x)


# -- inputs
n_q1 = 25
n_q2 = 5

x = np.random.rand(1, n_q1)  # todo: changes w/ time
y = np.random.rand(1, n_q2)  # todo: changes w/ time

# -- activations
n_p1 = 3
n_p2 = 2

v_p1 = np.random.rand(1, n_p1)
v_p2 = np.random.rand(1, n_p2)

a_p1 = 0.5
a_p2 = 0.5

# -- weights
W_q1p1 = np.random.rand(n_q1, n_p1)
W_p2q2 = np.random.rand(n_p2, n_q2)
W_p1p1 = np.random.rand(n_p1, n_p1)
W_p1p2 = np.random.rand(n_p1, n_p2)
W_p2p1 = np.random.rand(n_p2, n_p1)

# -- computation
t1 = tm()

for t in range(2000):

    z_p1 = np.matmul(v_p1, W_p1p1) + np.matmul(v_p2, W_p2p1) + np.matmul(x, W_q1p1)
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)

    z_p2 = np.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)

    v_p1, v_p2 = v_p1_new, v_p2_new

print(tm() - t1)

This yields: 0.02118515968322754

JAX code:

from jax import random, nn, numpy as jnp

from time import time as tm


def gp(x):
    # element-wise ReLU, matching the NumPy version
    return nn.relu(x)


# -- inputs
n_q1 = 25
n_q2 = 5
key = random.PRNGKey(0)
# split the key once per draw: reusing the same key would return
# identical samples for arrays of the same shape
keys = random.split(key, 9)

x = random.normal(keys[0], (1, n_q1))
y = random.normal(keys[1], (1, n_q2))

# -- activations
n_p1 = 3
n_p2 = 2

v_p1 = random.normal(keys[2], (1, n_p1))
v_p2 = random.normal(keys[3], (1, n_p2))

a_p1 = 0.5
a_p2 = 0.5

# -- weights
W_q1p1 = random.normal(keys[4], (n_q1, n_p1))
W_p2q2 = random.normal(keys[5], (n_p2, n_q2))

W_p1p1 = random.normal(keys[6], (n_p1, n_p1))
W_p1p2 = random.normal(keys[7], (n_p1, n_p2))
W_p2p1 = random.normal(keys[8], (n_p2, n_p1))

# -- computation
t1 = tm()
for t in range(2000):

    z_p1 = jnp.matmul(v_p1, W_p1p1) + jnp.matmul(v_p2, W_p2p1) + jnp.matmul(x, W_q1p1)
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)

    z_p2 = jnp.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)

    v_p1, v_p2 = v_p1_new, v_p2_new
    
print(tm() - t1)

This yields: 2.5548229217529297


Solution 1:

The JAX documentation has a useful section on this in its FAQ: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy

TL;DR: it's complicated. For individual operations on CPU, especially on tiny arrays like the (1, 3) and (1, 2) activations here, JAX is often slower than NumPy, because every dispatched operation carries a fixed overhead; your un-jitted Python loop pays that overhead on every one of its 2000 iterations (about 1.3 ms each, going by your measurement). JIT-compiled sequences of operations, on the other hand, are often faster than NumPy, and once you move to a GPU or TPU, JAX will generally be much faster than NumPy.
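
To see what JIT compilation does for this exact snippet, here is a minimal sketch (reusing x, v_p1, v_p2 and the weight matrices from your JAX code; the names run and step are mine, not part of any API) that fuses the whole 2000-step recurrence into a single compiled XLA call:

import jax
from jax import lax, nn, numpy as jnp
from time import time as tm


@jax.jit
def run(x, v_p1, v_p2, W_q1p1, W_p1p1, W_p1p2, W_p2p1):
    a_p1, a_p2 = 0.5, 0.5

    def step(t, carry):
        v_p1, v_p2 = carry
        z_p1 = jnp.matmul(v_p1, W_p1p1) + jnp.matmul(v_p2, W_p2p1) + jnp.matmul(x, W_q1p1)
        v_p1_new = a_p1 * v_p1 + (1 - a_p1) * nn.relu(z_p1)
        # note: z_p2 uses the old v_p1, exactly as in the question's loop
        z_p2 = jnp.matmul(v_p1, W_p1p2)
        v_p2_new = a_p2 * v_p2 + (1 - a_p2) * nn.relu(z_p2)
        return v_p1_new, v_p2_new

    # the whole recurrence runs inside one compiled XLA program
    # instead of thousands of separate Python-level dispatches
    return lax.fori_loop(0, 2000, step, (v_p1, v_p2))


# warm-up call: triggers tracing and compilation
v1, v2 = run(x, v_p1, v_p2, W_q1p1, W_p1p1, W_p1p2, W_p2p1)
v1.block_until_ready()

t1 = tm()
v1, v2 = run(x, v_p1, v_p2, W_q1p1, W_p1p1, W_p1p2, W_p2p1)
v1.block_until_ready()  # JAX dispatch is asynchronous; wait for the result
print(tm() - t1)

Two benchmarking details matter here: the first call to a jitted function includes compilation time, so time a later call; and because JAX executes asynchronously, you need block_until_ready() to ensure you measure the computation rather than just the dispatch. Without jit, each of the handful of operations in every loop iteration is dispatched to XLA one at a time, which is where your 2.5 s goes.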