jax: Initialize an array from a function
(I’m writing this so that I can dig it up later when I’ve inevitably
forgotten the trick in half a year.)
You want to initialize an \(N \times M\) array in jax. You already have a routine for computing the desired value at index (i,j). Normally you might use a nested for-loop like so:
import jax
import jax.numpy as jnp
N,M = 8,5 # Or whatever
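def f(i,j):
    return 2**i + 3**j # For example
ary = jnp.zeros((N,M))
for i in range(N):
    for j in range(M):
        # jax arrays are immutable, so ary[i,j] = ... raises a TypeError;
        # .at[i,j].set(...) returns an updated copy instead.
        ary = ary.at[i,j].set(f(i,j))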
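Alas, jax.jit traces straight through Python loops, so it sees only the unrolled loop: N·M separate update operations. Compilation time grows with the size of the array, and for anything but tiny arrays it will be horrendously slow.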
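To see the unrolling concretely, here is a rough sketch using jax.make_jaxpr, which shows the traced program that jax.jit would hand to the compiler. The helpers build_loop and count_eqns are hypothetical names for this illustration, and the loop uses .at[i, j].set(...) because jax arrays cannot be assigned in place. The equation count should grow roughly in proportion to N·M, though the exact numbers depend on the JAX version.

```python
import jax
import jax.numpy as jnp

def f(i, j):
    return 2**i + 3**j  # the example function from above

def build_loop(n, m):
    # Hypothetical helper: returns a loop-based builder for an n x m array.
    def build(offset):
        ary = jnp.zeros((n, m)) + offset
        for i in range(n):
            for j in range(m):
                # One staged update per (i, j): this is the unrolled loop.
                ary = ary.at[i, j].set(f(i, j))
        return ary
    return build

def count_eqns(n, m):
    # Number of primitive operations in the traced program (the jaxpr).
    closed = jax.make_jaxpr(build_loop(n, m))(0.0)
    return len(closed.jaxpr.eqns)

for n, m in [(2, 2), (4, 4), (8, 5)]:
    print(n, m, count_eqns(n, m))
```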
As usual, the loops may be profitably replaced with a call to jnp.vectorize. But jnp.vectorize(f) expects to be called on two arrays: one holding the row indices i and one holding the column indices j. They can be generated by jnp.indices, which takes a desired shape and outputs arrays listing their own indices. As an example, after a call a,b = jnp.indices((3,5)), a is an array of shape (3,5) obeying a[i,j]==i, and b is an array of shape (3,5) obeying b[i,j]==j.
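A small sketch to check this behaviour. Note that jnp.indices actually returns one stacked array, here of shape (2, 3, 5), with one slice per axis; tuple-unpacking it along the first axis is what produces a and b.

```python
import jax.numpy as jnp

idx = jnp.indices((3, 5))  # one stacked array of shape (2, 3, 5)
a, b = idx                 # unpack along the first axis

print(idx.shape)  # (2, 3, 5)
print(a)
# [[0 0 0 0 0]
#  [1 1 1 1 1]
#  [2 2 2 2 2]]
print(b)
# [[0 1 2 3 4]
#  [0 1 2 3 4]
#  [0 1 2 3 4]]
```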
Putting it all together (as a one-liner):
import jax
import jax.numpy as jnp
N,M = 8,5 # Or whatever
def f(i,j):
    return 2**i + 3**j # For example
ary = jnp.vectorize(f)(*jnp.indices((N,M)))
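As a sanity check, here is a sketch that compiles the one-liner with jax.jit and compares it against values computed by a plain Python double loop. The wrapper name init_array is hypothetical; n and m are passed as static arguments because jnp.indices needs a concrete shape when the function is traced.

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(i, j):
    return 2**i + 3**j  # the example function from above

@partial(jax.jit, static_argnums=(0, 1))
def init_array(n, m):
    # Hypothetical wrapper around the one-liner. The shape must be static,
    # because jnp.indices needs concrete dimensions at trace time.
    return jnp.vectorize(f)(*jnp.indices((n, m)))

N, M = 8, 5
ary = init_array(N, M)

# Reference values from a plain Python double loop (no JAX involved).
expected = [[f(i, j) for j in range(M)] for i in range(N)]
print(ary.shape, ary.dtype)
print(bool((ary == jnp.array(expected)).all()))
```

Marking n and m static means each new shape triggers a fresh compilation, but the traced program no longer grows with N·M the way the loop version did.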