# jax: Initialize an array from a function

*(I’m writing this so that I can dig it up later when I’ve inevitably
forgotten the trick in half a year.)*

You want to initialize an \(N \times M\) array in `jax`. You already have a routine for computing the desired value at index `(i,j)`. Normally you might use a nested for-loop like so:

```
import jax
import jax.numpy as jnp

N, M = 8, 5  # Or whatever

def f(i, j):
    return 2**i + 3**j  # For example

ary = jnp.zeros((N, M))
for i in range(N):
    for j in range(M):
        # jax arrays are immutable, so use the functional .at[...].set(...) update
        ary = ary.at[i, j].set(f(i, j))
```

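Alas, `jax.jit` traces straight through Python loops, so it sees only the fully unrolled loop: \(N \cdot M\) separate update operations in one program. Compilation time grows with the size of the array and quickly becomes horrendously slow.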
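
To see the unrolling concretely, here is a rough sketch that counts the operations in the traced program using `jax.make_jaxpr`. The helper names `build_loop` and `num_ops` are hypothetical, introduced only for this illustration; the exact operation count depends on the JAX version, but it grows with \(N \cdot M\).

```python
import jax
import jax.numpy as jnp

def f(i, j):
    return 2**i + 3**j

def build_loop(N, M):
    # Hypothetical helper: the loop-based initialization from above
    ary = jnp.zeros((N, M))
    for i in range(N):
        for j in range(M):
            ary = ary.at[i, j].set(f(i, j))
    return ary

def num_ops(N, M):
    # Hypothetical helper: the Python loops execute at trace time,
    # so every iteration leaves its own equation in the jaxpr
    closed = jax.make_jaxpr(lambda: build_loop(N, M))()
    return len(closed.jaxpr.eqns)

small, large = num_ops(2, 2), num_ops(8, 5)
print(small, large)  # the second count is roughly ten times the first
```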

As usual, the loops may be profitably replaced with a call to `jnp.vectorize`. But `jnp.vectorize(f)` expects to be called on two arrays, one holding the row index and one holding the column index of every entry. They can be generated by `jnp.indices`, which takes a desired shape and outputs arrays listing their own indices. As an example, after a call `a,b = jnp.indices((3,5))`, `a` is an array of shape `(3,5)` obeying `a[i,j]==i`, and `b` is an array of shape `(3,5)` obeying `b[i,j]==j`.
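
A small sketch of what `jnp.indices` returns for the `(3,5)` example, so the two index arrays can be inspected directly:

```python
import jax.numpy as jnp

a, b = jnp.indices((3, 5))

# a[i, j] == i, so every row of a is constant:
# [[0 0 0 0 0]
#  [1 1 1 1 1]
#  [2 2 2 2 2]]
print(a)

# b[i, j] == j, so every column of b is constant:
# [[0 1 2 3 4]
#  [0 1 2 3 4]
#  [0 1 2 3 4]]
print(b)
```

Unpacking the pair with `*jnp.indices(...)` therefore hands `f` the row index as its first argument and the column index as its second.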

Putting it all together (as a one-liner):

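```
import jax
import jax.numpy as jnp

N, M = 8, 5  # Or whatever

def f(i, j):
    return 2**i + 3**j  # For example

# jnp.indices((N, M)) yields the row-index and column-index arrays;
# jnp.vectorize(f) applies f elementwise across them
ary = jnp.vectorize(f)(*jnp.indices((N, M)))
```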
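
As a rough sanity check (a sketch, not part of the trick itself), one can wrap the one-liner in `jax.jit` and compare it against a plain-Python reference built with NumPy. The function name `build` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

N, M = 8, 5

def f(i, j):
    return 2**i + 3**j

@jax.jit
def build():
    # Hypothetical wrapper: N and M are Python ints, so the shape
    # passed to jnp.indices is static, as jax.jit requires
    return jnp.vectorize(f)(*jnp.indices((N, M)))

ary = build()

# Reference computed entry by entry with ordinary Python integers
ref = np.array([[f(i, j) for j in range(M)] for i in range(N)])
print(np.array_equal(np.asarray(ary), ref))
```

Unlike the loop version, the traced program here does not grow with \(N \cdot M\): `jnp.vectorize` hands `f` whole index arrays, so the operations in `f` are traced once rather than once per entry.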