 

Cannot import name 'linear_util' from 'jax'

I'm trying to reproduce the experiments from the S5 model (https://github.com/lindermanlab/S5), but I ran into problems setting up the environment. When I run the shell script ./run_lra_cifar.sh, I get the following error:

Traceback (most recent call last):
  File "/Path/S5/run_train.py", line 3, in <module>
    from s5.train import train
  File "/Path/S5/s5/train.py", line 7, in <module>
    from .train_helpers import create_train_state, reduce_lr_on_plateau,\
  File "/Path/train_helpers.py", line 6, in <module>
    from flax.training import train_state
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from . import core
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)

I'm running this on an RTX 4090 with CUDA 11.8. My jax version is 0.4.25 and my jaxlib version is 0.4.25+cuda11.cudnn86.
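To confirm which versions are actually installed in the interpreter that runs the script (the traceback points at a miniconda Python 3.12 site-packages), a small sketch like the one below can help. It reads package metadata with `importlib.metadata`, so it works even when `import flax` itself fails. The package list is an assumption; extend it as needed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Return {package: version string or None} without importing the packages."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current interpreter's environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running it with the same `python` that `run_lra_cifar.sh` uses avoids accidentally inspecting a different environment.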

I first tried to install the dependencies using the author's requirements file:

pip install -r requirements_gpu.txt

However, this doesn't seem to work in my case, since I can't even import jax. So I installed jax following the instructions at https://jax.readthedocs.io/en/latest/installation.html:

pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
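After reinstalling, one way to check that `import jax` works and that the GPU is visible is sketched below (assumption: it is run inside the same conda environment used for the S5 scripts). A result reporting the `cpu` backend suggests the CUDA-enabled jaxlib is not the one being used.

```python
def jax_backend_report():
    """Describe the JAX install without crashing if JAX is missing or broken."""
    try:
        import jax

        backend = jax.default_backend()
        devices = jax.devices()
    except Exception as exc:  # ImportError, or a jaxlib/CUDA initialization failure
        return f"jax unavailable: {exc!r}"
    return f"jax {jax.__version__}, default backend: {backend}, devices: {devices}"


if __name__ == "__main__":
    print(jax_backend_report())
```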

So far I've tried:

  1. Using an older GPU (3060 and 2070)
  2. Downgrading Python to 3.9

Does anyone know what could be wrong? Any help is appreciated.

WillWu asked Dec 01 '25 16:12


1 Answer

jax.linear_util was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.
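The version window stated above can be written as a small check. This is a sketch that only compares version strings (it does not import JAX); the 0.4.24 cutoff is taken from this answer, and local build tags such as `+cuda11.cudnn86` are ignored.

```python
import re

# Per this answer: jax.linear_util was removed in JAX 0.4.24.
LINEAR_UTIL_REMOVED_IN = (0, 4, 24)


def parse_release(version_string):
    """Extract (major, minor, patch) from e.g. '0.4.25+cuda11.cudnn86'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version: {version_string!r}")
    return tuple(int(part) for part in match.groups())


def jax_has_linear_util(jax_version):
    """True if this JAX release should still ship jax.linear_util."""
    return parse_release(jax_version) < LINEAR_UTIL_REMOVED_IN


if __name__ == "__main__":
    # The asker's version: linear_util is already gone, hence the ImportError.
    print(jax_has_linear_util("0.4.25+cuda11.cudnn86"))
```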

The traceback shows that the linear_util import comes from flax (flax/core/axes_scan.py), which means you are using an older flax version together with a newer jax version.
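To confirm that the installed flax still contains the removed import, the sketch below searches the package's source text. It locates flax with `importlib.util.find_spec`, which does not execute `flax/__init__.py`, so it works even though importing flax fails. The search string is an assumption based on the traceback line `from jax import linear_util as lu`.

```python
import importlib.util
from pathlib import Path

# Assumed pattern, copied from the failing line in the traceback.
OLD_IMPORT = "from jax import linear_util"


def files_with_old_import(package_dir):
    """Return .py files under package_dir that still contain the removed import."""
    hits = []
    for path in sorted(Path(package_dir).rglob("*.py")):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue
        if OLD_IMPORT in text:
            hits.append(path)
    return hits


def installed_flax_dir():
    """Locate the installed flax package directory without importing flax."""
    spec = importlib.util.find_spec("flax")
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(list(spec.submodule_search_locations)[0])


if __name__ == "__main__":
    flax_dir = installed_flax_dir()
    if flax_dir is None:
        print("flax is not installed")
    else:
        for path in files_with_old_import(flax_dir):
            print(path)
```

Any file it prints is code that will fail against JAX 0.4.24 or later.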

To fix your issue, you'll either need to install an older version of JAX which still has jax.linear_util, or update to a newer version of flax which is compatible with more recent JAX versions.
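Whichever option you choose, a quick check is whether the import that failed in the traceback (`from flax.training import train_state`) now succeeds in the same environment. The sketch below tries each module in turn; the module list mirrors the traceback and is an assumption you can extend.

```python
import importlib


def check_train_imports(modules=("jax", "flax.training.train_state")):
    """Try each import; map module name -> None on success or an error string."""
    results = {}
    for name in modules:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # ImportError is the expected failure mode here
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in check_train_imports().items():
        print(f"{name}: {'ok' if error is None else error}")
```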

jakevdp answered Dec 03 '25 06:12


