This is meant to provide some conceptual background that can be helpful for understanding JAX and Equinox. I wrote it for some coworkers who were ML researchers accustomed to working with PyTorch but didn't know much about programming language theory or implementation. I hope this page will make it much easier for people like that to come up to speed with these libraries.
JAX and Equinox both have good official documentation, so I encourage you to check those out after (or before) reading this.
JAX
I think it's helpful to think of JAX as a domain-specific language embedded within Python rather than as a Python library. Why have a domain-specific language? Basically because it's very hard to get arbitrary Python code to run fast. JAX is less flexible and harder to learn than PyTorch, but in return it tends to run faster and be easier to scale.
Besides the compiler that tends to make code faster, some powerful things that JAX makes easier than PyTorch are:
- Automatic vectorization (AKA processing batches of data).
- Automatic parallelization across multiple accelerators (GPUs).
But the main thing really is the compiler, so let's get into that.
XLA and compilation
JAX is often glossed as "Autograd and XLA": it combines automatic differentiation (from Autograd) with XLA ("Accelerated Linear Algebra"), a compiler that tries to produce high-performance code for linear algebra computations.
There are two main ways to convert source code into machine code that can run on a physical device:
- Interpretation. The source code is executed one statement at a time by an interpreter program. This is how Python and PyTorch work (ignoring torch.compile).
- Compilation. A big chunk of source code is translated to machine code once and then the machine code can be executed repeatedly. This is how C works.
JIT stands for "Just In Time" and it refers to a strategy for compiling code only when it is first run (contrasted with compilers for languages like C that compile ahead of time and produce a binary that may never run).
You use `jax.jit` to indicate that you want some code to be compiled (though it will not actually be compiled until it is first run, since it's JIT!). You don't have to use `jax.jit`, but if you don't, you should probably just use PyTorch instead, because JAX without it will be slower and less flexible.
Let's take a look!
try:
    import equinox as eqx
except ModuleNotFoundError:
    !pip install -q equinox
    import equinox as eqx
import jax
import jax.numpy as jnp
w = 12.0
def my_func(x: jax.Array, y: jax.Array):
    print(f"calling my_func({x, y}) with w = {w}")
    return x + w * y
my_func(jnp.ones(2,), 2*jnp.ones(2,))
calling my_func((Array([1., 1.], dtype=float32), Array([2., 2.], dtype=float32))) with w = 12.0
Array([25., 25.], dtype=float32)
Now let's compile it:
my_func_jit = jax.jit(my_func)
my_func_jit(jnp.ones(2,), 2*jnp.ones(2,))
calling my_func((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)) with w = 12.0
Array([25., 25.], dtype=float32)
When calling `my_func_jit`, its inputs are now `Traced` objects rather than `Array`s. This is because the JAX compiler works by first interpreting the code in regular Python, but with the array arguments replaced by `Traced` objects. During tracing, all of the array operations are recorded, and only those recorded operations are passed on to the next stage of compilation.
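If you want to see what actually gets recorded, `jax.make_jaxpr` (not used in the example above, but part of JAX's public API) traces a function and returns its "jaxpr", the intermediate representation that gets handed to the compiler:

# Tracing my_func records only the array operations (a multiply and an add);
# the print call leaves nothing behind in the jaxpr.
jax.make_jaxpr(my_func)(jnp.ones(2,), 2*jnp.ones(2,))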
Let's execute my_func_jit again:
my_func_jit(jnp.ones(2,), 2*jnp.ones(2,))
Array([25., 25.], dtype=float32)
You can see that non-array operations are not captured by the trace: the `print` statement does not execute this time. This shows that `my_func_jit` has already been compiled and is not being recompiled (since the arguments have the same shape and dtype as before).
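If instead we call it with arguments of a new shape or dtype, JAX traces (and compiles) the function again. A quick sketch:

# float32[3] inputs haven't been seen before, so my_func is traced once more:
# the print statement runs again and a second compiled version is cached.
my_func_jit(jnp.ones(3,), 2*jnp.ones(3,))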
Control flow
Python control flow (ifs, loops) is executed during tracing and does not become part of the compiled function. If you want control flow that depends on traced values, you need to use the `jax.lax` control-flow functions instead of the Python built-ins.
For example, this Python control flow fails when JIT compiled:
def python_control_flow(x):
    if x % 2:
        return x
    return x + 1
python_control_flow_jit = jax.jit(python_control_flow)
python_control_flow_jit(1)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
<ipython-input-9-92ad3ab87490> in <cell line: 8>()
      6 python_control_flow_jit = jax.jit(python_control_flow)
      7
----> 8 print(python_control_flow_jit(1))
      9 print(python_control_flow_jit(2))

    [... skipping hidden 11 frame]

<ipython-input-9-92ad3ab87490> in python_control_flow(x)
      1 def python_control_flow(x):
----> 2     if x % 2 == 0:
      3         return x
      4     return x + 1
      5

    [... skipping hidden 1 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in error(self, arg)
   1490     if fun is bool:
   1491       def error(self, arg):
-> 1492         raise TracerBoolConversionError(arg)
   1493     elif fun in (hex, oct, operator.index):
   1494       def error(self, arg):

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. The error occurred while tracing the function python_control_flow at <ipython-input-9-92ad3ab87490>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
But this logically equivalent version using a `jax.lax` control-flow function succeeds:
def jax_control_flow(x):
    return jax.lax.select(x % 2, x, x + 1)
jax_control_flow_jit = jax.jit(jax_control_flow)
jax_control_flow_jit(1)
Array(1, dtype=int32, weak_type=True)
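`jax.lax` provides several such control-flow functions (`cond`, `while_loop`, `scan`, and so on). As a quick sketch (the function name here is just for illustration), the same logic written with `jax.lax.cond`, which traces both branches and selects one at runtime:

def jax_control_flow_cond(x):
    # Both branch functions are traced; which one runs is decided at runtime.
    return jax.lax.cond(x % 2 == 1, lambda v: v, lambda v: v + 1, x)

jax.jit(jax_control_flow_cond)(1)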
Static arguments
In JAX, a static argument to a function is one that will not change after the function has been compiled. That is, calling the function with a different value of a static argument will trigger recompilation.
Python control-flow on static arguments is allowed, but you should only do this if the argument is some sort of hyperparameter or configuration that doesn't change throughout a computation, because compilation is slow.
def python_control_flow_with_static(x: jax.Array, identity: bool) -> jax.Array:
    print("identity:", identity)
    if identity:
        return x
    return x + 1
python_control_flow_with_static_jit = jax.jit(python_control_flow_with_static, static_argnums=(1,))
print(python_control_flow_with_static_jit(1, True))
print(python_control_flow_with_static_jit(2, True))
print(python_control_flow_with_static_jit(1, False))
print(python_control_flow_with_static_jit(2, False))
identity: True
1
2
identity: False
2
3
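As an aside, `jax.jit` also accepts `static_argnames`, which refers to static arguments by name instead of position. A sketch (the variable name is just for illustration), equivalent to the `static_argnums` version above:

# Same behavior as static_argnums=(1,), but keyed by the argument's name.
python_control_flow_with_static_jit2 = jax.jit(
    python_control_flow_with_static, static_argnames=("identity",)
)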
You can also use static arguments for other things, e.g. callables. If you forget to do this, the error message can be a bit mystifying at first:
def apply_to(fn, x):
    return fn(x)
jax.jit(apply_to)(jax.nn.relu, jnp.array([-1, 0, 1]))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-15-0ed9ab30caa5> in <cell line: 4>()
      2     return fn(x)
      3
----> 4 apply_to_jit = jax.jit(apply_to)(jax.nn.relu, jnp.array([-1, 0, 1]))

    [... skipping hidden 5 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py in _shaped_abstractify_slow(x)
    586     dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    587   else:
--> 588     raise TypeError(
    589         f"Cannot interpret value of type {type(x)} as an abstract array; it "
    590         "does not have a dtype attribute")

TypeError: Cannot interpret value of type <class 'jax._src.custom_derivatives.custom_jvp'> as an abstract array; it does not have a dtype attribute
The fix is to set the callable argument as static:
jax.jit(apply_to, static_argnums=(0,))(jax.nn.relu, jnp.array([-1, 0, 1]))
Array([0, 0, 1], dtype=int32)
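An alternative is to bake the callable in before jitting, for example with `functools.partial`, so that the jitted function only receives array arguments. A sketch (the variable name is just for illustration):

from functools import partial

# partial fixes fn=jax.nn.relu, so the wrapped callable only takes the array x
# and there is no non-array argument for jit to trip over.
apply_relu = jax.jit(partial(apply_to, jax.nn.relu))
apply_relu(jnp.array([-1, 0, 1]))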
Closures
Loosely speaking, a closure is the combination of a function and the variables it references from its surrounding scope. Let's take a look at our example from above again:
w = 12.0
def my_func(x: jax.Array, y: jax.Array):
    print(f"calling my_func({x, y}) with w = {w}")
    return x + w * y
my_func_jit = jax.jit(my_func)
my_func_jit(jnp.ones(2,), 2*jnp.ones(2,))
In this case the closure of `my_func` includes `w`. After tracing, variables like `w` that are not arguments to the function are baked into the compiled function, so changing them has no effect on an already-compiled function. For example:
w = 3.0
my_func_jit(jnp.ones(2,), 2*jnp.ones(2,))
Array([25., 25.], dtype=float32)
If we compile again, though, we can see that the new value of `w` is picked up. We can force recompilation by passing in arguments with a different shape or dtype.
my_func_jit_again = jax.jit(my_func)
my_func_jit_again(jnp.ones(3,), 2*jnp.ones(3,))
calling my_func((Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>)) with w = 3.0
Array([7., 7., 7.], dtype=float32)
Pytrees
JAX uses the term "pytree" to refer to a tree-like structure built out of Python objects. Basically any Python object is in theory a valid pytree as long as it doesn't contain reference cycles.
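For instance, any nesting of the built-in containers is a pytree, and `jax.tree_util.tree_leaves` extracts its leaves (a quick sketch):

# Registered containers are flattened recursively; everything else
# (ints, strings, arrays) is treated as a leaf.
jax.tree_util.tree_leaves({"a": [1, 2], "b": (jnp.zeros(3), "hi")})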
Custom Python classes have to be registered in order to be treated as pytree nodes (by default their instances are leaves), but the built-in container types (lists, dicts, tuples) are all preregistered.
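Here is a sketch of what registration looks like (the `Point` class is purely illustrative):

from jax import tree_util

@tree_util.register_pytree_node_class
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Return the dynamic children and any static auxiliary data.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# tree_map now recurses into Point like any other registered container.
jax.tree_util.tree_map(lambda a: a * 2, Point(jnp.ones(2), jnp.zeros(2)))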
JAX functions that take arrays typically work with pytrees as well, in ways that can feel a bit magical but are very powerful.
For example, automatic vectorization via jax.vmap works on pytrees: if we vmap a function that returns a pytree, the return value of the vmapped function has the same tree structure, but each array leaf gains a leading batch dimension matching the batched input.
def return_dict(x):
    assert x.shape == (), "Only scalars supported"
    return {"sin": jnp.sin(x), "cos": jnp.cos(x)}
return_dict_vmap = jax.vmap(return_dict)
return_dict_vmap(jnp.array([0.5, 1.0]))
{'cos': Array([0.87758255, 0.5403023 ], dtype=float32), 'sin': Array([0.47942555, 0.84147096], dtype=float32)}
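`jax.grad` is similar: if the differentiated argument is a pytree of arrays, the returned gradients have the same tree structure. A small sketch (the loss function here is just for illustration):

def loss(params, x):
    # params is a pytree (here a dict) of arrays
    return jnp.sum((params["w"] * x + params["b"]) ** 2)

params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
# Returns a dict {"w": ..., "b": ...} with the same structure as params.
jax.grad(loss)(params, jnp.arange(3.0))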
Equinox
Equinox is a neural network library for JAX. It mainly consists of:
- A `Module` base class that is a pytree and behaves like a dataclass.
- Some common neural network layers (e.g. linear, convolutional, attention) and helpers (e.g. `Sequential`); a quick example follows just after this list.
- Some helper functions for manipulating pytrees.
- Some wrappers to make it easier to use JAX transformations (`jax.jit`, `jax.grad`, etc.) with Modules or other pytrees.
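For instance, a tiny example with one of the built-in layers (the specific sizes are just for illustration):

key = jax.random.PRNGKey(0)
linear = eqx.nn.Linear(2, 3, key=key)  # a Module, and therefore a pytree
linear(jnp.ones(2))                    # applied to a single (unbatched) input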
Equinox's own documentation is pretty good, but let me take a stab at explaining the main value that Equinox provides.
JAX's transformations (e.g. `jit`, `grad`) work on Arrays or on pytrees whose leaves are Arrays. Equinox Modules typically hold a mix of Arrays and other objects at their leaves (e.g. a reference to a callable function like `jax.nn.relu`). If you pass a pytree with a non-Array leaf into such a transformation, JAX will complain. Equinox lets you keep your arrays and computation together in one pytree (a Module), and provides tools to mark all of the non-Array parts of the Module as static before passing it in to a transformation like `jit`.
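To make that concrete, here is a small sketch (the `Scale` module is purely illustrative; `eqx.Module` and `eqx.filter_jit` are the real Equinox pieces):

class Scale(eqx.Module):
    weight: jax.Array     # an array leaf: traced by filter_jit
    activation: callable  # a non-array leaf: treated as static

    def __call__(self, x):
        return self.activation(self.weight * x)

model = Scale(weight=jnp.array([1.0, -2.0]), activation=jax.nn.relu)

# eqx.filter_jit traces the array leaves of its arguments and treats everything
# else (like the reference to jax.nn.relu) as static, so this compiles fine.
@eqx.filter_jit
def apply_model(model, x):
    return model(x)

apply_model(model, jnp.ones(2))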
That's it for now! Let me know if you have questions or corrections.