
# jax: Initialize an array from a function

(I’m writing this so that I can dig it up later when I’ve inevitably forgotten the trick in half a year.)

You want to initialize an $N \times M$ array in jax. You already have a routine for computing the desired value at index (i,j). Normally you might use a nested for-loop like so:

```python
import jax
import jax.numpy as jnp

N,M = 8,5  # Or whatever

def f(i,j):
    return 2**i + 3**j  # For example

ary = jnp.zeros((N,M))
for i in range(N):
    for j in range(M):
        # jax arrays are immutable: .at[...].set(...) returns an updated copy
        ary = ary.at[i,j].set(f(i,j))
```


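Alas, `jax.jit` traces straight through the Python loops and sees only the fully unrolled sequence of $N \cdot M$ separate updates. Compilation time therefore grows with the size of the array and quickly becomes horrendously slow.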
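To see the unrolling for yourself, you can count the equations in the program that JAX traces. The sketch below uses `jax.make_jaxpr` on the loop version, written with the functional `.at[i, j].set(...)` update because jax arrays do not support in-place assignment. The helper names `build_with_loop` and `count_eqns` are hypothetical. The exact counts depend on the JAX version, but each element contributes at least one staged operation, so the count grows at least linearly with $N \cdot M$.

```python
import jax
import jax.numpy as jnp

def f(i, j):
    return 2**i + 3**j  # same example function as in the text

def build_with_loop(N, M):
    # Hypothetical helper: the nested-loop initialization from the text.
    ary = jnp.zeros((N, M))
    for i in range(N):
        for j in range(M):
            # Each iteration is traced separately and staged as its own update.
            ary = ary.at[i, j].set(f(i, j))
    return ary

def count_eqns(N, M):
    # Hypothetical helper: number of top-level equations in the traced program.
    closed = jax.make_jaxpr(lambda: build_with_loop(N, M))()
    return len(closed.jaxpr.eqns)

for shape in [(8, 5), (16, 10)]:
    print(shape, count_eqns(*shape))
```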

As usual, the loops may be profitably replaced with a call to `jnp.vectorize`. But `jnp.vectorize(f)` expects to be called on two arrays: one holding the row index of every entry and one holding the column index. Both can be generated by `jnp.indices`, which takes a desired shape and outputs arrays listing their own indices. For example, after the call `a,b = jnp.indices((3,5))`, `a` is an array of shape (3,5) obeying `a[i,j]==i`, and `b` is an array of shape (3,5) obeying `b[i,j]==j`.
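Here is a small sketch of that example. `jnp.indices` actually returns a single stacked array of shape `(2, 3, 5)`, and tuple unpacking splits it along the first axis. The toy function `g` is hypothetical. It simply encodes each index pair, so you can see that `jnp.vectorize(g)(a, b)` evaluates `g(a[i, j], b[i, j])` at every position.

```python
import jax.numpy as jnp

# One stacked array: stacked[0] holds row indices, stacked[1] column indices.
stacked = jnp.indices((3, 5))
print(stacked.shape)  # (2, 3, 5)

# Unpacking iterates over the first axis, giving one index array per dimension.
a, b = stacked
print(a)
# [[0 0 0 0 0]
#  [1 1 1 1 1]
#  [2 2 2 2 2]]
print(b)
# [[0 1 2 3 4]
#  [0 1 2 3 4]
#  [0 1 2 3 4]]

def g(i, j):
    return 10 * i + j  # hypothetical toy function: encodes the index pair

print(jnp.vectorize(g)(a, b))
# [[ 0  1  2  3  4]
#  [10 11 12 13 14]
#  [20 21 22 23 24]]
```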

Putting it all together (as a one-liner):

```python
import jax
import jax.numpy as jnp

N,M = 8,5  # Or whatever

def f(i,j):
    return 2**i + 3**j  # For example

ary = jnp.vectorize(f)(*jnp.indices((N,M)))
```
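If the initialization lives inside a jitted function, one point needs care. The output shape has to be known at trace time, so `N` and `M` must be marked static. The sketch below assumes that setup, and the wrapper name `init_from_function` is hypothetical. Because `jnp.vectorize` is built on `jax.vmap`, the traced program stays the same size no matter how large `N` and `M` are, unlike the unrolled loop.

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(i, j):
    return 2**i + 3**j  # same example function as in the text

# Hypothetical wrapper: N and M determine the output shape, so they are static.
@partial(jax.jit, static_argnums=(0, 1))
def init_from_function(N, M):
    return jnp.vectorize(f)(*jnp.indices((N, M)))

ary = init_from_function(8, 5)
print(ary.shape)  # (8, 5)
```

Each new `(N, M)` pair triggers one fresh compilation. Repeated calls with the same shape reuse the compiled version.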