
Neural Cellular Automata (based on the Distill article "Growing Neural Cellular Automata") implemented in Jax (Flax)


Installation

From source

git clone
cd jax-nca
pip install .

From PyPI

pip install jax-nca

How do NCAs work?

For more information, read the awesome article: Mordvintsev et al., "Growing Neural Cellular Automata", Distill, 2020.

The image below describes a single update step:
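The first stage of that update step is perception: every cell looks at its 3x3 neighbourhood. With `trainable_perception=False`, the NCA uses fixed Sobel filters instead of a trainable conv net. The following is a minimal sketch of that idea in plain JAX, not the jax_nca implementation: it assumes an NHWC state and the identity, Sobel-x and Sobel-y filter bank from the Distill article, and `sobel_perceive` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def sobel_perceive(state: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: identity, Sobel-x and Sobel-y responses per channel.

    state: (batch, height, width, channels), NHWC layout (assumed).
    returns: (batch, height, width, 3 * channels), ordered per channel as
    [identity, sobel_x, sobel_y].
    """
    channels = state.shape[-1]
    identity = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
    sobel_x = jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 8.0
    sobel_y = sobel_x.T
    # (3, 3, 1, 3): three 3x3 filters that each read a single input channel
    kernels = jnp.stack([identity, sobel_x, sobel_y], axis=-1)[:, :, None, :]
    # Repeat the filter bank for every channel -> HWIO kernel (3, 3, 1, 3 * channels)
    kernels = jnp.tile(kernels, (1, 1, 1, channels))
    return jax.lax.conv_general_dilated(
        state,
        kernels,
        window_strides=(1, 1),
        padding="SAME",  # zero padding at the grid border
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=channels,  # depthwise: each channel filtered on its own
    )
```

Because each channel is filtered independently, a state with C channels yields 3C perception features per cell, which the update network would then consume.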

Why Jax?

<b> Note: This project served as a nice introduction to jax, so its performance can probably be improved </b>

NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations much more performant using jax.lax.scan and jax.jit.

Instead of writing the NCA growth process as a Python for loop

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run

    for i in range(num_steps):
        current_state = nca.apply({"params": params}, current_state)
    return current_state

we can write it with jax.lax.scan:

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run

    def forward(carry, inp):
        carry = nca.apply({"params": params}, carry)
        return carry, carry

    final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
    return final_state
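The two versions above compute the same thing. The self-contained sketch below checks this with a hypothetical `toy_update` standing in for `nca.apply` (no Flax module and no RNG, so it runs without jax_nca installed). `num_steps` is marked static for `jax.jit` because `jax.lax.scan` needs its `length` at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_update(params, state):
    # Hypothetical stand-in for nca.apply: a small residual update.
    return state + params["alpha"] * jnp.tanh(state)


def multi_step_loop(params, state, num_steps):
    # Python loop: under jit this would be unrolled into num_steps copies.
    for _ in range(num_steps):
        state = toy_update(params, state)
    return state


@partial(jax.jit, static_argnums=2)  # num_steps fixes the scan length
def multi_step_scan(params, state, num_steps):
    def forward(carry, _):
        carry = toy_update(params, carry)
        return carry, carry  # (next carry, per-step output)

    final_state, all_states = jax.lax.scan(forward, state, None, length=num_steps)
    return final_state, all_states
```

With the Python loop, `jax.jit` would unroll `num_steps` copies of the update into the compiled program; `jax.lax.scan` traces the body once, so compile time stays flat as `num_steps` grows, and the stacked per-step states come back for free.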

The actual multi_step implementation can be found here:


See notebooks/Gecko.ipynb for a full example.

<b> Currently there's a bug with the stochastic update, so only cell_fire_rate = 1.0 works at the moment </b>
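The `cell_fire_rate` parameter controls the stochastic update from the Distill article: at each step, every cell independently decides whether to apply its update. Below is a minimal sketch of that masking in plain JAX, assuming an NHWC state; `stochastic_update` is a hypothetical helper, not jax_nca's code, and it does not address the bug mentioned above.

```python
import jax
import jax.numpy as jnp


def stochastic_update(rng, state, delta, cell_fire_rate):
    """Hypothetical helper: apply `delta` only to a random subset of cells.

    state, delta: (batch, height, width, channels), NHWC layout (assumed).
    Each cell (all of its channels together) is updated with probability
    cell_fire_rate; the other cells keep their current state this step.
    """
    batch, height, width, _ = state.shape
    # uniform samples lie in [0, 1), so rate 1.0 fires every cell and 0.0 none
    fire = jax.random.uniform(rng, (batch, height, width, 1)) < cell_fire_rate
    return state + delta * fire.astype(state.dtype)
```

With `cell_fire_rate=1.0` the mask is all ones, so every cell updates on every step and the update becomes deterministic.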

Creating and using an NCA

import flax.linen as nn

class NCA(nn.Module):
    """
    num_hidden_channels: Number of hidden channels for each cell to use
    num_target_channels: Number of target channels to be used
    alpha_living_threshold: Threshold to determine whether a cell lives or dies
    cell_fire_rate: Probability that a cell receives an update per step
    trainable_perception: If True, use a trainable conv net instead of Sobel filters
    alpha: Scalar value multiplied with the updates
    """
    num_hidden_channels: int
    num_target_channels: int = 3
    alpha_living_threshold: float = 0.1
    cell_fire_rate: float = 1.0
    trainable_perception: bool = False
    alpha: float = 1.0
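`alpha_living_threshold` corresponds to the alive masking in the Distill article: a cell is treated as alive if the maximum alpha value in its 3x3 neighbourhood exceeds the threshold, and dead cells are zeroed out. A minimal sketch follows; which channel holds alpha in jax_nca's state layout is not stated here, so `alpha_channel` is an assumed parameter and `alive_mask` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def alive_mask(state, alpha_channel, threshold=0.1):
    """Hypothetical helper: boolean (batch, height, width, 1) mask of living cells.

    A cell counts as alive if any cell in its 3x3 neighbourhood (itself
    included) has alpha > threshold. alpha_channel is an assumption about
    the state layout.
    """
    alpha = state[..., alpha_channel : alpha_channel + 1]
    # 3x3 max pool with stride 1; the border is padded with -inf (the init value)
    neighbourhood_max = jax.lax.reduce_window(
        alpha,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, 3, 3, 1),
        window_strides=(1, 1, 1, 1),
        padding=((0, 0), (1, 1), (1, 1), (0, 0)),
    )
    return neighbourhood_max > threshold
```

In the Distill formulation, the updated state is multiplied by this mask, so empty regions stay empty and growth spreads only from cells next to living ones.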

import jax
from jax_nca.nca import NCA

# usage
nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64, 64), batch_size=1
)
rng = jax.random.PRNGKey(0)
params = nca.init(rng, nca_seed, rng)["params"]
update = nca.apply({"params": params}, nca_seed, jax.random.PRNGKey(10))

# multi step

final_state, nca_states = nca.multi_step(params, nca_seed, jax.random.PRNGKey(10), num_steps=32)

To train the NCA:

from jax_nca.dataset import ImageDataset
from jax_nca.nca import NCA
from jax_nca.trainer import EmojiTrainer

dataset = ImageDataset(emoji='🦎', img_size=64)

nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

trainer = EmojiTrainer(dataset, nca, n_damage=0)

trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)

# to access train state:
state = trainer.state

# save params
nca.save(state.params, "saved_params")

# load params
loaded_params = nca.load("saved_params")
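The README does not spell out what `EmojiTrainer.train` does internally. The `min_steps`/`max_steps` arguments match the Distill training recipe: grow from the seed for a random number of steps, then compare the visible channels with the target image. The sketch below illustrates that loop with plain JAX and gradient descent; `toy_update`, `rollout_loss` and `train_step` are hypothetical, and the real trainer (optimizer, batching, damage) may differ.

```python
import random

import jax
import jax.numpy as jnp


def toy_update(params, state):
    # Hypothetical stand-in for one NCA step (nca.apply without RNG).
    return state + params["w"] * jnp.tanh(state)


def rollout_loss(params, seed_state, target, num_steps):
    # Unroll the update num_steps times from the seed.
    def forward(carry, _):
        return toy_update(params, carry), None

    final_state, _ = jax.lax.scan(forward, seed_state, None, length=num_steps)
    # Assumption: the first target.shape[-1] channels are the visible ones.
    visible = final_state[..., : target.shape[-1]]
    return jnp.mean((visible - target) ** 2)


# num_steps (positional arg 3) is static because it sets the scan length.
loss_and_grad = jax.jit(jax.value_and_grad(rollout_loss), static_argnums=3)


def train_step(params, seed_state, target, min_steps, max_steps, lr):
    # Sample the rollout length on the host, inclusive of both bounds.
    num_steps = random.randint(min_steps, max_steps)
    loss, grads = loss_and_grad(params, seed_state, target, num_steps)
    # Plain gradient descent; returns the loss measured before this update.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

Because the sampled `num_steps` is a static argument, each distinct rollout length is compiled once and then cached, so a range like 64 to 96 costs at most 33 compilations.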