 

Installing jaxlib for CUDA 11.8

Tags: pip, ubuntu, jax

I'm trying to install jax and jaxlib on Ubuntu 18 with Python 3.8 for snerg (https://github.com/google-research/google-research/tree/master/snerg). Unfortunately, when I try to install jax and jaxlib for CUDA 11.8 with the following command:

pip install --upgrade jax jaxlib==0.1.69+cuda118 -f https://storage.googleapis.com/jax-releases/jax_releases.html 

I get the following error:

ERROR: Ignored the following versions that require a different python version: 0.4.14 Requires-Python >=3.9
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.69+cuda118 (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15, 0.3.18, 0.3.20, 0.3.22, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13)
ERROR: No matching distribution found for jaxlib==0.1.69+cuda118

Any help would be appreciated. Thanks!

m0j1 asked Sep 05 '25 00:09


1 Answer

Follow these steps, which are taken primarily from the upstream JAX installation instructions:

Uninstall previous versions (if any):

$ pip uninstall jax jaxlib jaxtyping -y
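If you are not sure which of these packages are currently installed, a small standard-library check can list them before you uninstall. This is a minimal sketch assuming Python 3.8+ (where `importlib.metadata` is available); the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Return {distribution name: installed version, or None if absent}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in the current environment
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions(["jax", "jaxlib", "jaxtyping"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```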

Upgrade pip:

$ pip install --upgrade pip
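It is also worth confirming which Python you are running, because the pip error in the question reports that jaxlib 0.4.14 requires Python >=3.9, so on Python 3.8 the newest jaxlib pip offered was 0.4.13. A minimal sketch of that check follows; the function name `satisfies_requires_python` is hypothetical, and the (3, 9) threshold is taken only from that error message.

```python
import sys

# From the pip error in the question: "0.4.14 Requires-Python >=3.9".
JAXLIB_0_4_14_MIN_PYTHON = (3, 9)


def satisfies_requires_python(minimum, version_info=sys.version_info):
    """True if the running (or given) Python is at least `minimum`, e.g. (3, 9)."""
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    if satisfies_requires_python(JAXLIB_0_4_14_MIN_PYTHON):
        print("Python meets jaxlib 0.4.14's Requires-Python >=3.9")
    else:
        print("Python < 3.9: pip will not offer jaxlib 0.4.14 or later")
```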

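Check which CUDA version your NVIDIA driver supports (the "CUDA Version" field in the nvidia-smi header is the highest CUDA version the driver supports, not necessarily a CUDA toolkit you have installed):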

$ nvidia-smi

Thu Jan  4 11:24:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A1000 6GB Lap...    Off | 00000000:01:00.0 Off |                  N/A |
| N/A   58C    P0              12W /  35W |      8MiB /  6144MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3219      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+
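If you want to read that "CUDA Version" field from a script rather than by eye, a regular expression over the nvidia-smi header is enough. This is a sketch that assumes the header keeps the `CUDA Version: X.Y` format shown in the output above; the helper names are hypothetical.

```python
import re
import shutil
import subprocess

# Matches the header field printed by nvidia-smi, e.g. "CUDA Version: 12.2".
CUDA_FIELD = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def parse_driver_cuda_version(smi_output):
    """Return (major, minor) from nvidia-smi text, or None if the field is missing."""
    match = CUDA_FIELD.search(smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def query_driver_cuda_version():
    """Run nvidia-smi if it is on PATH and parse its CUDA Version field."""
    if shutil.which("nvidia-smi") is None:
        return None  # NVIDIA driver tools not found on this machine
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return parse_driver_cuda_version(result.stdout)


if __name__ == "__main__":
    print(query_driver_cuda_version())
```

On the machine shown above, this would yield `(12, 2)`.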

Depending on the CUDA version on your machine (the CUDA wheels are only available on Linux), run ONE of the following:

# CUDA 12.X installation
$ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

#### OR ####

# CUDA 11.X installation
# Note: wheels only available on Linux.
$ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
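The choice between the two commands above comes down to the CUDA major version. As a sketch, the mapping can be written out explicitly; it only knows the two extras named in this answer (`cuda12_pip` and `cuda11_pip`), and the function name `jax_install_command` is hypothetical.

```python
# Extras and find-links URL taken from the two pip commands above.
FIND_LINKS = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
EXTRAS = {12: "cuda12_pip", 11: "cuda11_pip"}


def jax_install_command(cuda_major):
    """Build the pip command for a CUDA major version (Linux wheels only)."""
    extra = EXTRAS.get(cuda_major)
    if extra is None:
        raise ValueError(f"no CUDA wheel extra listed here for CUDA {cuda_major}.x")
    return f'pip install --upgrade "jax[{extra}]" -f {FIND_LINKS}'


if __name__ == "__main__":
    # CUDA 11.8, as in the question, maps to the cuda11_pip extra.
    print(jax_install_command(11))
```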

To double-check that JAX has been configured to use the GPU:

$ python -c "import jax; print(f'Jax backend: {jax.default_backend()}')"
Jax backend: gpu 
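For a slightly more thorough check, you can also list the devices JAX sees and run a tiny computation, so that a broken CUDA setup shows up immediately rather than later in snerg. This is a sketch; the helper name `jax_backend_report` is hypothetical, and it returns a message instead of failing when JAX is not installed.

```python
import importlib.util


def jax_backend_report():
    """Describe the backend and devices JAX sees; tolerate JAX being absent."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax
    import jax.numpy as jnp

    backend = jax.default_backend()  # expected to be "gpu" when the CUDA wheels are used
    devices = ", ".join(str(d) for d in jax.devices())
    # A tiny computation forces the backend to actually execute something.
    total = float(jnp.arange(4.0).sum())
    return f"backend={backend}; devices=[{devices}]; sum check={total}"


if __name__ == "__main__":
    print(jax_backend_report())
```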
Färid Alijani answered Sep 07 '25 19:09