jax: Initialize an array from a function
(I’m writing this so that I can dig it up later when I’ve inevitably
forgotten the trick in half a year.)
You want to initialize an \(N \times M\) array in jax. You already have a routine for computing the desired value at index (i,j). Normally you might use a nested for-loop like so:
import jax
import jax.numpy as jnp
N,M = 8,5 # Or whatever
def f(i,j):
    return 2**i + 3**j # For example
ary = jnp.zeros((N,M))
for i in range(N):
    for j in range(M):
        # jax arrays are immutable, so ary[i,j] = ... would raise an error
        ary = ary.at[i,j].set(f(i,j))
Alas, jax.jit traces straight through Python loops and sees only the unrolled
sequence of N*M separate updates, so compilation gets horrendously slow as the
array grows.
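To make the unrolling visible, here is a minimal sketch that traces the loop version with jax.make_jaxpr and counts the operations the tracer recorded. The helper name build_with_loops is made up for this illustration, and the exact count depends on the jax version; the only point is that it grows at least as fast as N*M.

```python
import jax
import jax.numpy as jnp

N, M = 8, 5  # Or whatever

def f(i, j):
    return 2**i + 3**j  # For example

def build_with_loops():
    # Hypothetical helper: the nested-loop initialization from above.
    ary = jnp.zeros((N, M))
    for i in range(N):
        for j in range(M):
            ary = ary.at[i, j].set(f(i, j))
    return ary

# Trace the function roughly the way jax.jit would, without compiling it.
jaxpr = jax.make_jaxpr(build_with_loops)()

# Every loop iteration shows up as at least one traced operation.
num_ops = len(jaxpr.jaxpr.eqns)
print(num_ops, "traced operations for", N * M, "array entries")
```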
As usual, the loops may be profitably replaced with calls to jnp.vectorize. But
jnp.vectorize(f) expects to be called on two arrays. They can be generated by
jnp.indices, which takes a desired shape and returns one index array per axis,
each of that shape. As an example, after a call a,b = jnp.indices((3,5)), a is
an array of shape (3,5) obeying a[i,j]==i, and b is an array of shape
(3,5) obeying b[i,j]==j.
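A quick sketch to check what jnp.indices hands back for the (3,5) example. The names a and b follow the text; everything else is just for illustration.

```python
import jax.numpy as jnp

# jnp.indices returns a single array of shape (2, 3, 5);
# unpacking along the first axis gives one index array per dimension.
a, b = jnp.indices((3, 5))

print(a.shape, b.shape)
print(a)  # row index at every position
print(b)  # column index at every position

# Spot-check one entry: a holds the row index, b the column index.
i, j = 2, 4
row_ok = int(a[i, j]) == i
col_ok = int(b[i, j]) == j
print(row_ok, col_ok)
```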
Putting it all together (as a one-liner):
import jax
import jax.numpy as jnp
N,M = 8,5 # Or whatever
def f(i,j):
    return 2**i + 3**j # For example
ary = jnp.vectorize(f)(*jnp.indices((N,M)))
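As a sanity check, the sketch below compares the one-liner against a plain Python double loop, and also wraps it in jax.jit to confirm that it traces without any unrolling. The names build and expected are made up for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np

N, M = 8, 5  # Or whatever

def f(i, j):
    return 2**i + 3**j  # For example

# The one-liner from above.
ary = jnp.vectorize(f)(*jnp.indices((N, M)))

# Reference values from an ordinary Python double loop (plain ints, no jax).
expected = np.array([[f(i, j) for j in range(M)] for i in range(N)])

# Hypothetical wrapper: the same one-liner inside a jitted function.
@jax.jit
def build():
    return jnp.vectorize(f)(*jnp.indices((N, M)))

ary_jit = build()

print(np.array_equal(np.asarray(ary), expected))
print(np.array_equal(np.asarray(ary_jit), expected))
```

One caveat: f now receives jax arrays rather than Python ints, so it has to be written with operations that work on arrays (arithmetic, jnp functions) and without Python-level branching on its arguments.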