How to install trax, jax, jaxlib on M1 Mac on macOS 12?

Solution 1:

jaxlib + jax

After someone reported success using miniforge, I read this and watched this to understand how to use Anaconda and miniforge side by side.

I installed miniforge following Apple's "arm64 : Apple Silicon" instructions. For some reason, when I ran miniforge's conda init, it put the initialization code in ~/.bash_profile even though I'm using the zsh shell. I tried putting the code in ~/.zprofile manually instead, but it wouldn't load in interactive shells, so I ended up putting it where Anaconda had put its initialization code, in ~/.zshrc.
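Because conda init can end up writing to a different startup file than you expect, it can help to check which files actually contain a conda initialization block and which conda installation each block points at. The Python sketch below is a hypothetical helper (find_conda_init_blocks is not part of conda). It assumes the `# >>> conda initialize >>>` / `# <<< conda initialize <<<` markers and the single-quoted `.../bin/conda` path that conda init typically writes, so treat its output as a hint, not a definitive answer.

```python
import re
from pathlib import Path
from typing import Dict, List, Optional

# Marker lines that `conda init` writes around the block it manages
# (assumed format; compare with your own startup files).
BEGIN = "# >>> conda initialize >>>"
END = "# <<< conda initialize <<<"
# A single-quoted path to a conda executable, e.g. '/Users/me/miniforge3/bin/conda'.
CONDA_EXE = re.compile(r"'([^']*/bin/conda)'")
# The startup files discussed in this answer.
STARTUP_FILES = (".bash_profile", ".zprofile", ".zshrc")


def find_conda_init_blocks(home: Optional[Path] = None) -> Dict[str, List[str]]:
    """Map each startup file that has a conda init block to the conda
    executables referenced inside that block (hypothetical helper)."""
    home = home or Path.home()
    found: Dict[str, List[str]] = {}
    for name in STARTUP_FILES:
        path = home / name
        if not path.is_file():
            continue
        inside, saw_block, exes = False, False, []
        for line in path.read_text(errors="replace").splitlines():
            stripped = line.strip()
            if stripped == BEGIN:
                inside = saw_block = True
            elif stripped == END:
                inside = False
            elif inside:
                # Collect conda executable paths only from inside the managed block.
                exes.extend(CONDA_EXE.findall(line))
        if saw_block:
            found[name] = exes
    return found


if __name__ == "__main__":
    for name, exes in find_conda_init_blocks().items():
        print(f"~/{name}: {', '.join(exes) or 'block found, no conda path recognised'}")
```

Seeing, for example, a miniforge3 path in one file and an anaconda3 path in another makes it easier to tell which manager a new shell will pick up.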

This made miniforge the default manager. Following the very useful video above, I created a ~/.start_anaconda.sh script so I can use Anaconda as an alternative.

With miniforge, I:

  • created a new conda environment, mytraxenv, with conda create -n mytraxenv python=3, which currently gives Python 3.10.2

  • activated the environment: conda activate mytraxenv

  • ran conda install numpy and conda install six to ensure that numpy, six and wheel (the last installed as a dependency of one of the previous two) were present in my mytraxenv environment

  • tried again, with a slightly updated release (from here):

    pip install -U pip
    pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.75-cp310-none-macosx_11_0_arm64.whl

This worked in installing jaxlib!
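The wheel filename says what it needs: cp310 means CPython 3.10 and macosx_11_0_arm64 means an arm64 (Apple Silicon) macOS build, so the interpreter in mytraxenv has to match both. As a rough pre-check, here is a Python sketch that parses those tags, compares them with the running interpreter, and confirms that numpy, six and wheel are importable. The helper names are my own (hypothetical); the sketch ignores the macOS version in the platform tag and optional build tags, and pip's own tag matching is more thorough.

```python
import importlib.util
import platform
import sys

# The wheel URL used above.
WHEEL = ("https://storage.googleapis.com/jax-releases/mac/"
         "jaxlib-0.1.75-cp310-none-macosx_11_0_arm64.whl")


def wheel_tags(url: str) -> dict:
    """Split name-version-pythontag-abitag-platformtag.whl into its parts.

    Sketch only: wheels with an optional build tag (6 fields) are not handled.
    """
    filename = url.rsplit("/", 1)[-1]
    if filename.endswith(".whl"):
        filename = filename[: -len(".whl")]
    name, version, py_tag, abi_tag, plat_tag = filename.split("-")
    return {"name": name, "version": version, "python": py_tag,
            "abi": abi_tag, "platform": plat_tag}


def check_environment(url: str = WHEEL) -> list:
    """List mismatches between this interpreter and the wheel's tags, plus any
    of numpy, six and wheel that are not importable here."""
    tags = wheel_tags(url)
    problems = []
    here = f"cp{sys.version_info.major}{sys.version_info.minor}"
    if tags["python"] != here:
        problems.append(f"wheel is built for {tags['python']}, this Python is {here}")
    # platform.machine() reports x86_64 when Python runs under Rosetta.
    machine = platform.machine()
    if not tags["platform"].endswith(machine):
        problems.append(f"wheel targets {tags['platform']}, this Python runs on {machine}")
    for pkg in ("numpy", "six", "wheel"):
        if importlib.util.find_spec(pkg) is None:
            problems.append(f"{pkg} is not installed in this environment")
    return problems


if __name__ == "__main__":
    print(check_environment() or "environment looks compatible with the wheel")
```

Run it with the environment's python after conda activate mytraxenv; an empty list only means nothing obvious is wrong.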

Then, I followed these instructions to install jax:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"
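To double-check which versions of jax and jaxlib ended up in mytraxenv after these commands (for example, whether jaxlib is still the 0.1.75 build from the wheel above), a small Python sketch using the standard library's importlib.metadata works; it reports roughly what pip show jax jaxlib would. installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")) -> dict:
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for pkg in packages:
        try:
            result[pkg] = version(pkg)
        except PackageNotFoundError:
            result[pkg] = None
    return result


if __name__ == "__main__":
    for pkg, ver in installed_versions().items():
        print(f"{pkg}: {ver or 'not installed'}")
```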

That worked as well. Note that running import jax in Python currently prints this warning:

/mytraxenv/lib/python3.10/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
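The warning is informational; jax still imports. If you'd rather not see it on every import, one option is a targeted warnings filter installed before jax is first imported. Below is a Python sketch that does that and then runs a tiny computation as a smoke test. The helper names are hypothetical, and hiding the notice doesn't make JAX on Mac ARM any less experimental.

```python
import warnings


def silence_mac_arm_notice() -> None:
    """Hide only JAX's "experimental on Mac ARM" UserWarning.

    `message` is a regular expression matched (case-insensitively) against
    the start of the warning text, so other warnings are still shown.
    """
    warnings.filterwarnings(
        "ignore",
        message=r"JAX on Mac ARM machines is experimental",
        category=UserWarning,
    )


def jax_smoke_test():
    """Import jax and run a tiny computation; return None if jax is not installed."""
    # The notice is raised while jax is first imported, so install the filter before that.
    silence_mac_arm_notice()
    try:
        import jax.numpy as jnp
    except ImportError:
        return None
    x = jnp.arange(4.0)
    return float(jnp.dot(x, x))  # 0 + 1 + 4 + 9


if __name__ == "__main__":
    print(jax_smoke_test())
```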

trax

No success yet installing trax.