FAQ

Best practices with jit and vmap

JIT

When using jax.jit together with quax.quaxify, the recommended pattern is to apply jax.jit at the outermost level:

import jax
import quax

# Preferred: jit wraps the entire quaxified computation.
jit_fn = jax.jit(quax.quaxify(fn))

# Also works, but may be less efficient: quaxify wraps an already-jitted function.
jit_fn = quax.quaxify(jax.jit(fn))

Placing jit at the outermost level follows the same best practice that applies to other JAX transforms. With this ordering, the dispatch logic introduced by quaxify runs only once, while jit traces the function, and the entire computation is then compiled as a single unit. The reversed ordering (quax.quaxify(jax.jit(fn))) also produces correct results, but it can be less efficient: quaxify then sits outside any compiled region, so its Python-level dispatch is re-executed on every call, and the inner jax.jit traces fn in isolation from the surrounding quaxified computation.
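To illustrate the cost of running wrapper logic outside jit, the sketch below replaces quaxify with a hypothetical counting_wrapper (not part of Quax) that records each time its Python body executes. It relies only on standard jax.jit caching: a jitted function is traced once per input shape and dtype, and later calls reuse the compiled result without re-running Python.

```python
import jax
import jax.numpy as jnp


def counting_wrapper(f, calls):
    """Hypothetical stand-in for a wrapper such as quax.quaxify.

    Appends to `calls` every time its Python body runs, standing in for
    the Python-level dispatch work a real wrapper would do.
    """
    def wrapped(x):
        calls.append(1)
        return f(x)
    return wrapped


def fn(x):
    return jnp.sin(x) * 2.0


outer_calls, inner_calls = [], []

# jit outermost: the wrapper body only runs while jit traces the function.
outer_fn = jax.jit(counting_wrapper(fn, outer_calls))
# jit innermost: the wrapper body runs in Python on every call.
inner_fn = counting_wrapper(jax.jit(fn), inner_calls)

x = jnp.arange(3.0)
for _ in range(5):
    y_outer = outer_fn(x)
    y_inner = inner_fn(x)

print(len(outer_calls), len(inner_calls))  # 1 5
```

Both orderings return the same values; the only difference is how often the wrapper's Python code runs. This counter is a simplification and does not model how Quax itself handles an inner jit.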

vmap

When using jax.vmap together with quax.quaxify, the ordering does not affect correctness. Place the transforms in whichever order best describes the operation you want to perform:

import jax
import quax

# Both orderings are valid; choose whichever matches your intent.

# quaxify on the outside: vmap is applied inside the quaxified function.
vmap_fn = quax.quaxify(jax.vmap(fn))

# Or vmap on the outside: vmap is applied to the quaxified function.
vmap_fn = jax.vmap(quax.quaxify(fn))
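One way to see why the two orderings can express different intents: Quax values are Equinox modules, and therefore pytrees, and jax.vmap maps over the leading axis of every array leaf of a pytree input. So with jax.vmap on the outside, each array field of a value is sliced per batch element before quaxify sees it. With quaxify on the outside, vmap runs inside the quaxified function and, as far as we understand Quax's design, sees each value as a single array-like input with the shape the value reports. The sketch below shows only the first, pure-JAX half of that, using a hypothetical ScaledArray NamedTuple as a stand-in for a Quax value.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ScaledArray(NamedTuple):
    """Hypothetical pytree standing in for a Quax value: represents values * scale."""
    values: jax.Array  # shape (batch, n)
    scale: jax.Array   # shape (batch,)


seen_shapes = []


def fn(s):
    # Record the shape of each leaf as seen *inside* the mapped function.
    seen_shapes.append((s.values.shape, s.scale.shape))
    # Hypothetical "materialise" step: combine the leaves into a plain array.
    return jnp.sum(s.values * s.scale[..., None])


batch = ScaledArray(values=jnp.ones((4, 3)), scale=jnp.arange(4.0))
out = jax.vmap(fn)(batch)

print(seen_shapes)  # [((3,), ())]
print(out)          # [0. 3. 6. 9.]
```

Every leaf loses its leading axis, including the scalar-per-example scale field, which is the behaviour to keep in mind when choosing jax.vmap(quax.quaxify(fn)).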