I'm trying to install jax and jaxlib on Ubuntu 18 with Python 3.8 for SNeRG (https://github.com/google-research/google-research/tree/master/snerg). Unfortunately, when I try to install jax and jaxlib for CUDA 11.8 with the following command:
pip install --upgrade jax jaxlib==0.1.69+cuda118 -f https://storage.googleapis.com/jax-releases/jax_releases.html
I get the following error:
ERROR: Ignored the following versions that require a different python version: 0.4.14 Requires-Python >=3.9
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.69+cuda118 (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15, 0.3.18, 0.3.20, 0.3.22, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13)
ERROR: No matching distribution found for jaxlib==0.1.69+cuda118
I would appreciate any help. Thanks!
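The error log already shows where resolution fails. The sketch below uses only the Python standard library and copies the "from versions:" list from the error above. It shows that the public version 0.1.69 is offered but no build carrying the +cuda118 local version label is, and that the newest version offered to this Python 3.8 interpreter is 0.4.13. The helper names split_local and version_key are illustrative only, not part of pip.

```python
from typing import Tuple

# Versions copied verbatim from the "from versions:" list in the pip error above.
AVAILABLE = (
    "0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, "
    "0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, "
    "0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, "
    "0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, "
    "0.3.10, 0.3.14, 0.3.15, 0.3.18, 0.3.20, 0.3.22, 0.3.24, 0.3.25, 0.4.0, "
    "0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, "
    "0.4.13"
).split(", ")


def split_local(version: str) -> Tuple[str, str]:
    """Split a PEP 440 version into (public part, local label); the label follows '+'."""
    public, _, local = version.partition("+")
    return public, local


def version_key(version: str) -> Tuple[int, ...]:
    """Numeric sort key for plain dotted versions such as '0.4.13'."""
    return tuple(int(part) for part in version.split("."))


requested = "0.1.69+cuda118"
public, local = split_local(requested)
print(f"{public} offered: {public in AVAILABLE}")           # 0.1.69 offered: True
print(f"{requested} offered: {requested in AVAILABLE}")    # 0.1.69+cuda118 offered: False
print("newest offered:", max(AVAILABLE, key=version_key))  # newest offered: 0.4.13
```

Plain string splitting is enough here because every entry in the log is a simple dotted version; for arbitrary PEP 440 strings a real version parser would be safer.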
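The requested wheel does not exist: jaxlib 0.1.69 predates CUDA 11.8, so no +cuda118 build of it was ever published. The log also shows that jaxlib 0.4.14 requires Python >=3.9, so on Python 3.8 the newest version pip can offer is 0.4.13. Instead, follow the steps below, which are taken primarily from the JAX installation instructions: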
Uninstall previous versions (if any):
$ pip uninstall jax jaxlib jaxtyping -y
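To confirm the uninstall really removed everything (for example, if pip belongs to a different Python than the one you run), a small standard-library check can query the installed distributions. This is a sketch: importlib.metadata is available from Python 3.8, and the function name installed_versions is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # After the uninstall step, all three entries should print None.
    for name, ver in installed_versions(["jax", "jaxlib", "jaxtyping"]).items():
        print(f"{name}: {ver}")
```

Run it with the same interpreter you will later use for SNeRG, so it inspects the right environment.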
Upgrade your pip:
$ pip install --upgrade pip
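Find out which CUDA version your NVIDIA driver supports. The "CUDA Version" field in the nvidia-smi header is the highest CUDA version the installed driver supports, not necessarily a CUDA toolkit installed on the machine: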
$ nvidia-smi
Thu Jan  4 11:24:58 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A1000 6GB Lap...    Off | 00000000:01:00.0 Off |                  N/A |
| N/A   58C    P0              12W /  35W |      8MiB /  6144MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3219      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+
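If you want to script this step, the driver's CUDA version can be pulled out of the nvidia-smi header with a regular expression. A minimal sketch, assuming the header format shown above (it may differ between driver releases); in practice you would pass it the text from subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout. The function name driver_cuda_version is illustrative.

```python
import re
from typing import Optional, Tuple

# Header line copied from the nvidia-smi output above.
SAMPLE = "| NVIDIA-SMI 535.129.03    Driver Version: 535.129.03   CUDA Version: 12.2     |"


def driver_cuda_version(nvidia_smi_output: str) -> Optional[Tuple[int, int]]:
    """Return (major, minor) from the 'CUDA Version: X.Y' field, or None if absent.

    This is the newest CUDA version the driver supports, not an installed toolkit.
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


print(driver_cuda_version(SAMPLE))  # (12, 2)
```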
Depending on the CUDA version reported by nvidia-smi, run ONE of the following (the CUDA wheels are only available on Linux):
# CUDA 12.X installation
$ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#### OR ####
# CUDA 11.X installation
$ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
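The choice between the two commands depends only on the CUDA major version. A hedged sketch of that rule follows; the function name jax_install_command is illustrative, and it only knows the two extras (cuda12_pip and cuda11_pip) used in the commands above, so any other major version is rejected rather than guessed.

```python
JAX_CUDA_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"

# Extras taken from the two commands above, keyed by CUDA major version.
EXTRAS = {12: "cuda12_pip", 11: "cuda11_pip"}


def jax_install_command(cuda_major: int) -> str:
    """Build the pip command for the given CUDA major version (11 or 12)."""
    try:
        extra = EXTRAS[cuda_major]
    except KeyError:
        raise ValueError(f"no jax CUDA extra listed here for CUDA {cuda_major}") from None
    return f'pip install --upgrade "jax[{extra}]" -f {JAX_CUDA_INDEX}'


print(jax_install_command(12))
```

For the machine above (driver reporting CUDA 12.2), this prints the CUDA 12.X command.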
To double-check that you have successfully configured the GPU:
$ python -c "import jax; print(f'Jax backend: {jax.default_backend()}')"
Jax backend: gpu
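The one-liner above only reports the default backend. A slightly fuller check, sketched below, also lists the devices JAX sees and runs a small jitted computation on the default device, and it reports cleanly if jax is not importable at all. The function name gpu_status is illustrative.

```python
import importlib.util


def gpu_status() -> str:
    """Describe the JAX backend and devices, or explain why they could not be checked."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this interpreter"
    import jax
    import jax.numpy as jnp

    # Run a tiny compiled computation to make sure the backend actually works:
    # (0 + 1 + 2 + 3) * 2 = 12.0
    total = float(jax.jit(lambda x: (x * 2).sum())(jnp.arange(4.0)))
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={jax.default_backend()} devices=[{devices}] check={total}"


if __name__ == "__main__":
    # On a working GPU setup the backend should be "gpu".
    print(gpu_status())
```

If it prints backend=cpu, JAX fell back to the CPU, which usually means the CUDA-enabled jaxlib was not installed into this interpreter.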