I have JAX code with nested scan loops and a vmap in the inner loop. The code takes a long time to compile and execute. I read some articles and learned the following things.
First, scan makes the computational graph deep, which makes computation time longer.
Second, scan does not benefit from parallelization, because each iteration depends on the previous one, so its execution is slow, especially on GPUs.
Based on this knowledge (I am not sure whether it is correct), I turned the outer scan loop into a Python loop. There is no @jit around the outer loop.
The result is that it helps reduce compilation time, but execution time remains almost the same.
Is this a general or expected behavior? I expected that the execution time would also be reduced.
The key thing to know is that JAX flattens for loops during tracing: so if you have a for loop with N iterations, JAX will effectively generate a program with N copies of the loop body. Compared to an equivalent scan, this has two effects:
It leads to longer compilation time, because more lines in your program mean the compiler has more degrees of freedom and has to spend more time exploring them for possible optimizations.
It (often, but not always) leads to faster runtime, because the compiler has more degrees of freedom to exploit when optimizing the code. Typically speaking though, the added compilation cost will far outweigh the improved runtime, especially for a large number of iterations.
So when replacing a scan with a for loop, you'd expect compilation time to increase, and runtime to either stay the same or marginally improve, depending on what's in the loop body and whether the compiler can optimize operations across iterations.
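To make the flattening concrete, here is a minimal sketch that traces the same loop both ways and counts the equations in the resulting program. The loop body `step`, the array `xs`, and the function names are hypothetical stand-ins, not the asker's code; the exact equation counts depend on the JAX version, so only their relative size matters.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # Hypothetical loop body: one cheap update per iteration.
    new_carry = jnp.sin(carry) + x
    return new_carry, new_carry

def with_scan(init, xs):
    # The body is traced once and stored inside a single scan operation.
    final, _ = lax.scan(step, init, xs)
    return final

def with_for_loop(init, xs):
    # The Python loop is unrolled during tracing: one copy of the body per iteration.
    carry = init
    for i in range(xs.shape[0]):
        carry, _ = step(carry, xs[i])
    return carry

xs = jnp.linspace(0.0, 1.0, 20)
init = jnp.zeros(())

# Number of top-level equations in each traced program.
n_eqns_scan = len(jax.make_jaxpr(with_scan)(init, xs).jaxpr.eqns)
n_eqns_loop = len(jax.make_jaxpr(with_for_loop)(init, xs).jaxpr.eqns)
print("scan:", n_eqns_scan, "for loop:", n_eqns_loop)

# Both versions compute the same value once compiled.
out_scan = jax.jit(with_scan)(init, xs)
out_loop = jax.jit(with_for_loop)(init, xs)
```

The unrolled program grows with the number of iterations, while the scan version stays the same size, which is why the compiler has more work to do in the for-loop case.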
Edit based on edited question:
Outside of jax.jit, each JAX operation is asynchronously dispatched, meaning that the Python code will generally run ahead of the actual computation. In the case of the for loop outside jax.jit, each operation within the loop body will be individually dispatched in sequence. If you're attempting to benchmark such code, you should be careful to account for the effects of asynchronous dispatch by adding jax.block_until_ready around the final result, otherwise you may just be measuring dispatch time rather than the runtime of the actual computation (see FAQ: Benchmarking JAX Code for more discussion of this).
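A rough sketch of how such a benchmark might look is below. The function `body`, the array size, and the step count are made up for illustration; only `jax.block_until_ready` and the `block_until_ready` array method come from JAX itself. The first timing may stop before the device has finished, while the second waits for the actual result.

```python
import time
import jax
import jax.numpy as jnp

def body(x):
    # Hypothetical loop body standing in for the real computation.
    return jnp.tanh(x @ x)

def python_loop(x, n_steps=10):
    # Not wrapped in jax.jit: each operation is dispatched individually.
    for _ in range(n_steps):
        x = body(x)
    return x

x = jnp.ones((256, 256))

# Warm-up run so one-time compilation of the individual ops is not timed.
python_loop(x).block_until_ready()

# Without blocking: this may mostly measure how long dispatch takes.
start = time.perf_counter()
result = python_loop(x)
time_without_block = time.perf_counter() - start

# With blocking: waits until the computation has actually finished.
start = time.perf_counter()
result = jax.block_until_ready(python_loop(x))
time_with_block = time.perf_counter() - start

print(f"without block: {time_without_block:.6f} s, with block: {time_with_block:.6f} s")
```

How far apart the two numbers are depends on the backend and on how heavy the loop body is, so treat them as an illustration of the measurement pitfall rather than as a result.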