
JAX: does jitting functions separately change performance?

Tags: jit, jax

I am learning to use JAX and I have some doubts about the use of jit and vmap that I couldn't solve by reading the docs.

  1. Does it make a difference to jit several functions separately and then jit the function that uses them? For example, if I have the functions foo() and bar() and a function

    @jax.jit 
    def fooBar(x):
        return foo(x) + bar(x)
    

    Is there any difference if foo() and bar() are already jitted?

  2. Should I jit a function after I vmap it? In the example above, should I do jax.jit(jax.vmap(fooBar)) or just jax.vmap(fooBar)?

Federico Taschin asked Jan 17 '26 12:01



1 Answer

In terms of execution performance, there is no difference between jitting the functions separately and jitting once at the outer function. Functionally there is one subtle difference: jit-compiling the inner function wraps its contents in an xla_call primitive in the traced program, but this makes little to no difference to the final compilation and execution.
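To make this concrete, here is a minimal sketch comparing the two arrangements. The bodies of foo and bar are hypothetical stand-ins (the question does not say what they compute), and the names foo_bar_nested and foo_bar_flat are introduced only for illustration. The sketch checks that both versions return the same values and builds the traced program of each so the extra wrapping can be inspected; it does not measure timing.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the foo() and bar() from the question.
def foo(x):
    return jnp.sin(x) * 2.0

def bar(x):
    return jnp.cos(x) + 1.0

# Variant 1: the inner functions are jitted on their own,
# and the outer function is jitted as well.
foo_jit = jax.jit(foo)
bar_jit = jax.jit(bar)

@jax.jit
def foo_bar_nested(x):
    return foo_jit(x) + bar_jit(x)

# Variant 2: only the outer function is jitted.
@jax.jit
def foo_bar_flat(x):
    return foo(x) + bar(x)

x = jnp.linspace(0.0, 1.0, 8)
out_nested = foo_bar_nested(x)
out_flat = foo_bar_flat(x)

# Both variants compute the same values.
same = bool(jnp.allclose(out_nested, out_flat))

# The traced programs can be printed to see how the nested version
# wraps the inner calls; the primitive's name depends on the JAX version.
nested_jaxpr = jax.make_jaxpr(foo_bar_nested)(x)
flat_jaxpr = jax.make_jaxpr(foo_bar_flat)(x)
```

Printing nested_jaxpr and flat_jaxpr side by side shows the structural difference described above, while the numerical results agree.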

With vmap, on the other hand, there is no implicit compilation. vmap(f) is executed in eager mode, while jit(vmap(f)) is just-in-time compiled and will generally run faster.
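A small sketch of the two options for the second question, again with a hypothetical body for foo_bar and an assumed batch shape of (4, 3). It only shows that both forms give the same result; which one is faster in practice would need to be measured on your own workload.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for fooBar from the question; operates on one example.
def foo_bar(x):
    return jnp.sin(x) * 2.0 + jnp.cos(x) + 1.0

# vmap alone: batched over the leading axis, but run op by op in eager mode.
batched = jax.vmap(foo_bar)

# jit(vmap(...)): the whole batched computation is compiled once per input shape.
batched_jit = jax.jit(jax.vmap(foo_bar))

xs = jnp.arange(12.0).reshape(4, 3)  # assumed batch of 4 examples, 3 features each
out_eager = batched(xs)
out_jit = batched_jit(xs)

# Same values either way; only the execution mode differs.
same = bool(jnp.allclose(out_eager, out_jit))
```

When timing the two, call .block_until_ready() on the result, since JAX dispatches work asynchronously, and exclude the first call to batched_jit, which includes compilation.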

jakevdp answered Jan 21 '26 07:01



