I am learning to use JAX and I have some doubts about the use of jit and vmap that I couldn't resolve by reading the docs.
Does it make a difference to jit several functions separately and then jit the function that uses them? For example, if I have the functions foo() and bar() and a function
```python
import jax

@jax.jit
def fooBar(x):
    return foo(x) + bar(x)
```
Is there any difference if foo() and bar() are already jitted?
Should I jit a function after I vmap it? In the example above, should I use jax.jit(jax.vmap(fooBar)) or just jax.vmap(fooBar)?
In terms of execution performance, there is no difference between jitting the inner functions separately and jitting only the outer function: once the outer function is jitted, foo() and bar() are traced as part of it and compiled together either way. There is one subtle functional difference: jit-compiling an inner function wraps its contents in an xla_call primitive inside the outer function's trace, but this makes little to no difference to the final compilation and execution.
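A minimal sketch of the comparison, assuming hypothetical definitions of foo() and bar() (the question doesn't give them). Both variants produce the same values, and jax.make_jaxpr shows the structural difference: in the nested variant each inner jit appears as one call equation wrapping its own sub-program, while in the plain variant the operations appear inline. The exact primitive name printed for the nested calls depends on the JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for foo() and bar(); the question does not define them.
def foo(x):
    return jnp.sin(x) * 2.0

def bar(x):
    return jnp.cos(x) ** 2

foo_jit = jax.jit(foo)
bar_jit = jax.jit(bar)

def foo_bar_plain(x):
    # Inner functions are plain Python; only the outer function gets jitted below.
    return foo(x) + bar(x)

def foo_bar_nested(x):
    # Inner functions are already jitted before the outer function is jitted.
    return foo_jit(x) + bar_jit(x)

x = jnp.linspace(0.0, 1.0, 5)
y_plain = jax.jit(foo_bar_plain)(x)
y_nested = jax.jit(foo_bar_nested)(x)
print(jnp.allclose(y_plain, y_nested))  # True

# Inspect the traced programs: the nested version contains one call equation per
# inner jit (plus the add), the plain version contains sin/mul/cos/pow/add inline.
jaxpr_plain = jax.make_jaxpr(foo_bar_plain)(x)
jaxpr_nested = jax.make_jaxpr(foo_bar_nested)(x)
print(jaxpr_plain)
print(jaxpr_nested)
```

Either way, the outer jit compiles the whole computation, so the nested jit calls only change how the trace is organized, not what runs.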
vmap, on the other hand, does no implicit compilation: vmap(f) is executed in eager mode, while jit(vmap(f)) is just-in-time compiled and will generally run faster.
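One way to see the difference is to count how often the Python body of the function runs. This is a minimal sketch with a hypothetical per-example function per_example and a side-effect counter used purely for illustration: plain vmap re-executes the Python body on every call (each operation is then dispatched separately), while jit(vmap(...)) traces and compiles once and reuses the cached executable for inputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}

def per_example(x):
    # Hypothetical per-example function; the counter records each time Python runs this body.
    trace_count["n"] += 1
    return jnp.sum(jnp.tanh(x) ** 2)

batched = jax.vmap(per_example)               # no compilation of the whole function
batched_jit = jax.jit(jax.vmap(per_example))  # compiled once per input shape/dtype, then cached

xs = jnp.ones((8, 3))

trace_count["n"] = 0
for _ in range(3):
    out_eager = batched(xs)
eager_runs = trace_count["n"]  # the Python body ran on every call

trace_count["n"] = 0
for _ in range(3):
    out_jit = batched_jit(xs)
jit_runs = trace_count["n"]    # the Python body ran only while tracing the first call

print(eager_runs, jit_runs)  # 3 1
print(jnp.allclose(out_eager, out_jit))  # True
print(out_jit.shape)  # (8,)
```

Note that side effects such as the counter run only at trace time under jit, which is exactly why it stops increasing after the first call.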