Step 15: JAX for high-performance GPU computing#


In the upcoming steps, we will explore JAX, a powerful Python library, and demonstrate how it can be used either to modify existing code or to develop new code optimized for efficient GPU computing. This step focuses on introducing the fundamental functions of the JAX library.

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance computational research. Here is a short video introducing JAX: https://www.youtube.com/watch?v=SFKEQs_Hu2c&t=94s

JAX = JIT + AutoGrad + XLA

JIT: just-in-time (JIT) compilation

AutoGrad: automatic differentiation

XLA: Accelerated Linear Algebra

To test the JAX code on GPUs, it is advisable to use Google Colab, which offers access to cloud-based GPU resources. This is particularly helpful if you don't have a local GPU; if you do have local GPU resources, they work just as well for this purpose.
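
Once your runtime is set up, you can confirm which backend JAX sees. As a quick check (a small sketch, not part of the original notes), the built-in jax.devices() and jax.default_backend() helpers report what is available:

import jax
print(jax.devices())           # e.g. a single CUDA device on a Colab GPU runtime
print(jax.default_backend())   # 'gpu', 'tpu', or 'cpu'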

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random
np.zeros(10)
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
jnp.zeros(10)
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

Multiplying Matrices#

We’ll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers: instead of a global seed, JAX uses explicit PRNG keys that are passed to every random function, which keeps random number generation reproducible and functionally pure.

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
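
Reusing the same key always reproduces the same numbers. To draw fresh, independent samples you split the key; here is a minimal sketch (not in the original notes) using the standard random.split():

key, subkey = random.split(key)    # keep key for later use, consume subkey now
x2 = random.normal(subkey, (10,))  # independent of the x drawn above
print(x2)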

Let’s dive right in and multiply two big matrices.

size = 3000
x = random.normal(key, (size,size), dtype = jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
6.35 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

We added that block_until_ready because JAX uses asynchronous execution by default (see Asynchronous dispatch).

y = np.random.rand(size, size)
%timeit np.dot(y, y.T)
1.13 s ± 669 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

You can see that jax.numpy (running on the GPU) is much faster than NumPy (running on the CPU) at matrix multiplication!

JAX NumPy functions work on regular NumPy arrays.

x = np.random.normal(size = (size,size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
36.7 ms ± 375 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

That’s slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using device_put().

from jax import device_put

x = np.random.normal(size = (size,size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
4.6 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

The output of device_put() still acts like an NDArray, but it only copies values back to the CPU when they’re needed for printing, plotting, saving to disk, branching, etc. The behavior of device_put() is equivalent to the function jit(lambda x: x), but it’s faster.
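
As a small illustration (a sketch, not part of the original notes): a device-backed array can be used wherever a NumPy array is expected, and an explicit host copy is made only when you ask for one, e.g. via np.asarray():

x_dev = device_put(np.arange(5.0))   # push a small host array to the default device
y_dev = jnp.sin(x_dev)               # the computation stays on the device
y_host = np.asarray(y_dev)           # explicit copy back to a host NumPy array
print(type(y_host), y_host)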

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

jit(), for speeding up your code

grad(), for taking derivatives

vmap(), for automatic vectorization or batching.

Let’s go over these, one-by-one. We’ll also end up composing these in interesting ways.

Using jit() to speed up functions#

JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit (just-in-time compile) decorator to compile multiple operations together using XLA (Accelerated Linear Algebra). Let’s try that.

@jit
def selu(x, alpha = 1.67, lmbda = 1.05):
  return lmbda * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (10000000,))
%timeit selu(x).block_until_ready()
246 µs ± 8.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Here selu() is applied to an array of 10 million elements. In this scenario (a function that will be called many times on large inputs), @jit pays off: the function is jit-compiled with XLA the first time it is called, and the compiled version is cached and reused thereafter. Wrapping an existing function with jit() is one way to apply it:

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
258 µs ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Another, equivalent, way is to apply @jit as a decorator when the function is defined:

@jit
def selu(x, alpha = 1.67, lmbda = 1.05):
  return lmbda * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (10000000,))
%timeit selu(x).block_until_ready()
258 µs ± 6.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Taking derivatives with grad()#

In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the grad() function.

With the configuration jax.config.update("jax_enable_x64", True), JAX will use float64 precision by default. Note that this setting should be applied at the start of your program, before any JAX arrays are created or computations are run.

Also, remember that using double precision comes with increased memory usage and computational cost, especially on GPUs.

import jax
jax.config.update("jax_enable_x64", True)

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(100.)
y_small = jnp.linspace(0,3,100)

derivative_fn = grad(sum_logistic)  # define a function which is the gradient of sum_logistic
%timeit derivative_fn(x_small)
print(derivative_fn(x_small))
7.21 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
[2.50000000e-01 1.96611933e-01 1.04993585e-01 4.51766597e-02
 1.76627062e-02 6.64805667e-03 2.46650929e-03 9.10221180e-04
 3.35237671e-04 1.23379350e-04 4.53958077e-05 1.67011429e-05
 6.14413685e-06 2.26031919e-06 8.31527336e-07 3.05902133e-07
 1.12535149e-07 4.13993738e-08 1.52299793e-08 5.60279637e-09
 2.06115361e-09 7.58256042e-10 2.78946809e-10 1.02618796e-10
 3.77513454e-11 1.38879439e-11 5.10908903e-12 1.87952882e-12
 6.91440011e-13 2.54366565e-13 9.35762297e-14 3.44247711e-14
 1.26641655e-14 4.65888615e-15 1.71390843e-15 6.30511676e-16
 2.31952283e-16 8.53304763e-17 3.13913279e-17 1.15482242e-17
 4.24835426e-18 1.56288219e-18 5.74952226e-19 2.11513104e-19
 7.78113224e-20 2.86251858e-20 1.05306174e-20 3.87399763e-21
 1.42516408e-21 5.24288566e-22 1.92874985e-22 7.09547416e-23
 2.61027907e-23 9.60268005e-24 3.53262857e-24 1.29958143e-24
 4.78089288e-25 1.75879220e-25 6.47023493e-26 2.38026641e-26
 8.75651076e-27 3.22134029e-27 1.18506486e-27 4.35961000e-28
 1.60381089e-28 5.90009054e-29 2.17052201e-29 7.98490425e-30
 2.93748211e-30 1.08063928e-30 3.97544974e-31 1.46248623e-31
 5.38018616e-32 1.97925988e-32 7.28129018e-33 2.67863696e-33
 9.85415469e-34 3.62514092e-34 1.33361482e-34 4.90609473e-35
 1.80485139e-35 6.63967720e-36 2.44260074e-36 8.98582594e-37
 3.30570063e-37 1.21609930e-37 4.47377931e-38 1.64581143e-38
 6.05460190e-39 2.22736356e-39 8.19401262e-40 3.01440879e-40
 1.10893902e-40 4.07955867e-41 1.50078576e-41 5.52108228e-42
 2.03109266e-42 7.47197234e-43 2.74878501e-43 1.01122149e-43]

Let’s verify with finite differences that our result is correct.

def first_finite_differences(f, x):
  eps = 1e-6
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

%timeit first_finite_differences(sum_logistic, x_small)
print(first_finite_differences(sum_logistic, x_small))
# x = jnp.arange(10)
# for v in jnp.eye(len(x)):
#   print(v)
# jnp.array([v for v in jnp.eye(len(x))])

208 ms ± 13.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[2.49999999e-01 1.96611936e-01 1.04993582e-01 4.51766624e-02
 1.76627069e-02 6.64805810e-03 2.46650700e-03 9.10219455e-04
 3.35241168e-04 1.23385746e-04 4.53965754e-05 1.66977543e-05
 6.14619466e-06 2.25952590e-06 8.31335001e-07 3.05533376e-07
 1.13686838e-07 4.26325641e-08 1.42108547e-08 7.10542736e-09
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]

The results agree closely for the leading entries, but you can see that the built-in grad() is much faster and more accurate (the finite differences are lost to floating-point rounding where the true derivative is tiny).

Taking derivatives is as easy as calling grad(), and grad() and jit() compose and can be mixed arbitrarily. For example, we can nest them:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.03532558051623561

For more advanced autodiff, you can use jax.vjp() for reverse-mode vector-Jacobian products and jax.jvp() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
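
The hessian() helper is not exercised in the original notes, but as a small hypothetical sketch we can apply it to sum_logistic from earlier; since each input feeds an independent sigmoid, the Hessian is a diagonal matrix:

x_test = jnp.array([0.0, 1.0, 2.0])
print(hessian(sum_logistic)(x_test))   # a 3x3 matrix of second derivatives, diagonal here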

Auto-vectorization with vmap()#

JAX has one more transformation in its API that you might find useful: vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with jit(), it can be just as fast as adding the batch dimensions by hand.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

Before setting that up, here is a quick look at how jnp.dot, elementwise multiplication (*), and the @ operator behave in JAX:

A = random.normal(key, (4,4))
x = jnp.arange(4)
y = jnp.arange(4)
z = jnp.stack([x,y],1) # shape (4, 2); with axis=0 the shape would be (2, 4) and jnp.dot(A, z) would fail
print(x)
print(A)
print(jnp.dot(A,x))  # matrix-vector product A x
print(A*x)           # elementwise broadcast: column j of A is scaled by x[j]

print(z)
print(jnp.dot(A,z))
print(A @ z)         # A @ z is matrix multiplication
[0 1 2 3]
[[-0.53389115  0.84179134  0.81155729  0.05308707]
 [ 0.72478811 -0.53911566 -0.21932127  0.5509203 ]
 [ 0.16972549  1.19717228 -1.06094203  0.28213284]
 [-1.05431656  1.01875438 -0.42167228 -2.58898201]]
[ 2.62416712  0.67500269 -0.07831328 -7.59153623]
[[-0.          0.84179134  1.62311457  0.15926121]
 [ 0.         -0.53911566 -0.43864254  1.65276089]
 [ 0.          1.19717228 -2.12188407  0.84639851]
 [-0.          1.01875438 -0.84334457 -7.76694604]]
[[0 0]
 [1 1]
 [2 2]
 [3 3]]
[[ 2.62416712  2.62416712]
 [ 0.67500269  0.67500269]
 [-0.07831328 -0.07831328]
 [-7.59153623 -7.59153623]]
[[ 2.62416712  2.62416712]
 [ 0.67500269  0.67500269]
 [-0.07831328 -0.07831328]
 [-7.59153623 -7.59153623]]
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

#apply_matrix(batched_x)  # results in a shape error: mat is (150, 100), batched_x is (10, 100)

Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched], 0)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.1 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

We know how to batch this operation manually. In this case, jnp.dot handles extra batch dimensions transparently.

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
131 µs ± 28.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

However, suppose we had a more complicated function without batching support. We can use vmap() to add batching support automatically.

@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
147 µs ± 56.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Of course, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.
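
For example, a common pattern (a minimal sketch, not from the original notes) is computing per-example derivatives by composing vmap() with grad(), and compiling the whole thing with jit():

def logistic(xi):
  return 1.0 / (1.0 + jnp.exp(-xi))            # scalar in, scalar out

per_example_grad = jit(vmap(grad(logistic)))   # derivative of each element, vectorized and compiled
print(per_example_grad(jnp.array([0.0, 1.0, 2.0])))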

This is just a taste of what JAX can do. We’re really excited to see what you do with it!

How to Think in JAX#

JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

JAX vs. NumPy#

Key Concepts:#

JAX provides a NumPy-inspired interface for convenience.

Through duck-typing, JAX arrays can often be used as drop-in replacements for NumPy arrays.

Unlike NumPy arrays, JAX arrays are always immutable.

NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides jax.numpy which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with numpy can be done with jax.numpy.
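
As a small illustration of the immutability point above (a sketch, not part of the original notes): NumPy arrays can be updated in place, while JAX arrays are updated functionally with the .at[...].set(...) syntax, which returns a new array:

a_np = np.zeros(3)
a_np[0] = 1.0                  # in-place update is fine in NumPy
print(a_np)                    # [1. 0. 0.]

a_jnp = jnp.zeros(3)
# a_jnp[0] = 1.0               # raises an error: JAX arrays are immutable
a_new = a_jnp.at[0].set(1.0)   # functional update returns a new array
print(a_jnp, a_new)            # the original array is unchanged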

NumPy, lax & XLA: JAX API layering#

Key Concepts:#

jax.numpy is a high-level wrapper that provides a familiar interface.

jax.lax is a lower-level API that is stricter and often more powerful.

All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax. You can think of jax.lax as a stricter, but often more powerful, API for working with multi-dimensional arrays.
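
A small example of this strictness (a sketch in the spirit of the official "How to Think in JAX" guide, not from the original notes): jax.numpy promotes mixed types automatically, while jax.lax requires the dtypes to match:

from jax import lax

print(jnp.add(1, 1.0))                 # jax.numpy promotes int and float automatically
# lax.add(1, 1.0)                      # jax.lax raises an error for mixed dtypes
print(lax.add(jnp.float32(1), 1.0))    # works once the types are made to match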

To JIT or not to JIT#

Key Concepts:#

By default JAX executes operations one at a time, in sequence.

Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.

As a general rule of thumb, apply JIT (Just-In-Time) compilation only to functions that will be executed repeatedly in the program.
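
To illustrate the shape constraint (a sketch along the lines of the official JAX docs, not from the original notes): a function whose output shape depends on the values of its input cannot be jitted:

def get_negatives(x):
  return x[x < 0]                  # the output length depends on the values in x

x = jnp.array([0.5, -1.0, 2.0, -3.0])
print(get_negatives(x))            # works eagerly: [-1. -3.]
# jit(get_negatives)(x)            # fails: the output shape is unknown at trace time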

I encourage you to explore the further tutorials available on the official JAX website to deepen your understanding of its features and usage. For now, however, you have sufficient knowledge to proceed with converting our existing code to a JAX-enhanced version.