Building a Modular Neural Network with JAX, Flax, and Optax
This guide walks you through constructing and training a sophisticated neural network using JAX, Flax, and Optax, emphasizing modularity and computational efficiency. We start by architecting a deep model that fuses residual connections with self-attention layers to capture rich spatial and contextual features. Next, we implement advanced optimization techniques including learning rate warmup, cosine decay scheduling, gradient norm clipping, and adaptive weight decay. Leveraging JAX’s powerful transformations — JIT compilation and automatic differentiation — we ensure accelerated execution and seamless scaling across hardware accelerators.
Setting Up the Environment and Dependencies
# Install runtime dependencies (Jupyter/Colab shell magic; remove outside a notebook).
!pip install jax jaxlib flax optax matplotlib
import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap
import flax.linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
from typing import Any
# Report version and visible accelerators so GPU/TPU availability can be confirmed.
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
We begin by installing the necessary libraries and importing essential modules for numerical computation, model definition, optimization, and visualization. Confirming the available devices ensures that JAX is configured to utilize GPUs or TPUs if present, which is critical for efficient training.
Designing the Neural Network Architecture
Our model integrates residual blocks with a multi-head self-attention mechanism to enhance feature representation. This hybrid design allows the network to learn both local spatial patterns and global dependencies, improving generalization across diverse datasets.
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over a (batch, seq_len, dim) input.

    Attributes:
        num_heads: Number of attention heads; must evenly divide the input
            feature dimension.
        embed_dim: Declared embedding size. Kept for interface compatibility;
            the layer derives the working dimension from the input shape.
    """
    num_heads: int
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        batch_size, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        # Joint projection producing query, key and value in one Dense call.
        qkv = nn.Dense(3 * dim)(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, head_dim)
        # Rearrange to (3, batch, heads, seq, head_dim) so attention mixes
        # sequence positions. The original einsum labels interpreted the
        # (batch, seq, heads, head_dim) layout as (batch, heads, seq, head_dim),
        # which made the softmax attend across heads instead of positions.
        qkv = qkv.transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention logits: (batch, heads, query, key).
        attn_logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
        attn_weights = jax.nn.softmax(attn_logits, axis=-1)
        # Contract weights with values over the shared key index. The original
        # used disjoint labels ('bhqk,bhvd'), which summed weights and values
        # independently instead of applying the attention weights.
        attn_output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
        # Restore (batch, seq, heads, head_dim) order, then merge the heads.
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
        # Final projection mixes information across heads.
        return nn.Dense(dim)(attn_output)
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in a residual connection.

    Attributes:
        channels: Number of output feature channels.
    """
    channels: int

    @nn.compact
    def __call__(self, x, training: bool = True):
        residual = x
        y = nn.Conv(self.channels, kernel_size=(3, 3), padding='SAME')(x)
        y = nn.BatchNorm(use_running_average=not training)(y)
        y = nn.relu(y)
        y = nn.Conv(self.channels, kernel_size=(3, 3), padding='SAME')(y)
        y = nn.BatchNorm(use_running_average=not training)(y)
        # Project the shortcut with a 1x1 conv when the channel count changes.
        if residual.shape[-1] != self.channels:
            residual = nn.Conv(self.channels, kernel_size=(1, 1))(residual)
        return nn.relu(y + residual)
class CustomCNN(nn.Module):
    """Residual CNN with a self-attention stage and a dense classifier head.

    Attributes:
        num_classes: Number of output logits.
    """
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, training: bool = True):
        # Stem convolution.
        x = nn.relu(nn.Conv(32, (3, 3), padding='SAME')(x))
        # First residual stage, then spatial downsampling, then second stage.
        for width in (64, 64):
            x = ResidualConvBlock(width)(x, training)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        for width in (128, 128):
            x = ResidualConvBlock(width)(x, training)
        # Global average pooling collapses the spatial dimensions.
        x = jnp.mean(x, axis=(1, 2))
        # Attend over a length-1 sequence, then drop the sequence axis.
        x = MultiHeadSelfAttention(num_heads=4, embed_dim=128)(x[:, None, :])
        x = x.squeeze(1)
        # Classifier head with dropout regularization.
        x = nn.relu(nn.Dense(256)(x))
        x = nn.Dropout(0.5, deterministic=not training)(x)
        return nn.Dense(self.num_classes)(x)
Configuring the Optimizer and Learning Rate Schedule
To promote stable and efficient training, we implement a learning rate schedule that starts with a linear warmup phase followed by cosine decay. The optimizer combines AdamW with gradient clipping and weight decay to prevent exploding gradients and overfitting.
class TrainStateWithBatchStats(train_state.TrainState):
    """Flax TrainState extended to carry mutable batch-norm statistics."""
    # Running BatchNorm statistics (the 'batch_stats' collection from model.init).
    batch_stats: Any
def build_lr_schedule(base_lr=1e-3, warmup_steps=100, decay_steps=1000) -> optax.Schedule:
    """Linear warmup from 0 to base_lr, then cosine decay toward 0.1 * base_lr.

    Args:
        base_lr: Peak learning rate reached at the end of warmup.
        warmup_steps: Length of the linear warmup phase, in steps.
        decay_steps: Length of the cosine decay phase, in steps.

    Returns:
        An optax schedule mapping a step count to a learning rate.
    """
    phases = [
        optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=warmup_steps),
        optax.cosine_decay_schedule(init_value=base_lr, decay_steps=decay_steps, alpha=0.1),
    ]
    return optax.join_schedules(phases, boundaries=[warmup_steps])
def setup_optimizer(lr_schedule: optax.Schedule) -> optax.GradientTransformation:
    """Build AdamW with global-norm gradient clipping and weight decay."""
    clipping = optax.clip_by_global_norm(1.0)
    adamw = optax.adamw(learning_rate=lr_schedule, weight_decay=1e-4)
    return optax.chain(clipping, adamw)
Efficient Training and Evaluation Steps with JIT Compilation
We define JIT-compiled functions for training and evaluation to maximize performance. The training step calculates gradients, updates model parameters, and manages batch normalization statistics dynamically. Evaluation computes loss and accuracy metrics to monitor model progress.
@jit
def calculate_metrics(logits, labels):
    """Return mean cross-entropy loss and accuracy for one batch of logits."""
    predictions = jnp.argmax(logits, axis=-1)
    return {
        'loss': optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean(),
        'accuracy': jnp.mean(predictions == labels),
    }
def initialize_train_state(rng, model, input_shape, lr_schedule):
    """Initialize model variables and bundle them into a train state.

    Args:
        rng: PRNG key for parameter initialization.
        model: Flax module to initialize.
        input_shape: Shape of a dummy input batch, e.g. (1, 32, 32, 3).
        lr_schedule: Learning-rate schedule handed to the optimizer.

    Returns:
        A TrainStateWithBatchStats ready for training_step.
    """
    variables = model.init(rng, jnp.ones(input_shape), training=False)
    return TrainStateWithBatchStats.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=setup_optimizer(lr_schedule),
        batch_stats=variables.get('batch_stats', {}),
    )
@jit
def training_step(state, batch, dropout_rng):
    """Run one optimization step on a single batch.

    Args:
        state: Current TrainStateWithBatchStats.
        batch: Tuple of (images, integer labels).
        dropout_rng: PRNG key for the dropout masks.

    Returns:
        (updated state, metrics dict with 'loss' and 'accuracy').
    """
    images, labels = batch

    def loss_fn(params):
        # Named `variables` rather than `vars` to avoid shadowing the builtin.
        variables = {'params': params, 'batch_stats': state.batch_stats}
        logits, updates = state.apply_fn(
            variables, images, training=True,
            mutable=['batch_stats'], rngs={'dropout': dropout_rng},
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, (logits, updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params)
    # Apply gradients and carry forward the refreshed batch-norm statistics.
    state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
    metrics = calculate_metrics(logits, labels)
    return state, metrics
@jit
def evaluation_step(state, batch):
    """Compute loss/accuracy on a batch without updating any model state."""
    images, labels = batch
    # Named `variables` rather than `vars` to avoid shadowing the builtin.
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, images, training=False)
    return calculate_metrics(logits, labels)
Generating Synthetic Data and Batching
To facilitate rapid experimentation without external datasets, we create synthetic image data and corresponding labels. Data batching is implemented to feed the model in manageable chunks during training and evaluation.
def create_synthetic_dataset(rng, num_samples=1000, image_size=32):
    """Generate random RGB images and integer class labels in [0, 10).

    Args:
        rng: PRNG key; split internally into image and label keys.
        num_samples: Number of examples to create.
        image_size: Height/width of each square image.

    Returns:
        (images, labels): float images with shape
        (num_samples, image_size, image_size, 3) and int labels (num_samples,).
    """
    image_key, label_key = random.split(rng)
    image_shape = (num_samples, image_size, image_size, 3)
    return (
        random.normal(image_key, image_shape),
        random.randint(label_key, (num_samples,), 0, 10),
    )
def batch_data(images, labels, batch_size=32):
    """Yield consecutive (images, labels) mini-batches of exactly batch_size.

    Trailing examples that do not fill a complete batch are dropped.
    """
    usable = (len(images) // batch_size) * batch_size
    for start in range(0, usable, batch_size):
        stop = start + batch_size
        yield images[start:stop], labels[start:stop]
Training Loop and Performance Visualization
The training routine orchestrates data generation, model initialization, and iterative optimization over multiple epochs. We track training loss and accuracy, as well as validation accuracy, plotting these metrics to visualize learning trends.
def run_training(num_epochs=5, batch_size=32):
    """Train CustomCNN on synthetic data and return (history, final state).

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size for both training and validation.

    Returns:
        history: Dict with per-epoch 'train_loss', 'train_accuracy',
            'val_accuracy' lists of Python floats.
        state: Final TrainStateWithBatchStats after the last epoch.
    """
    rng = random.PRNGKey(42)
    # Split distinct keys for train and validation data. The original reused a
    # single data key for both datasets — a JAX PRNG anti-pattern that makes
    # the two draws statistically dependent instead of independent.
    rng, train_rng, val_rng, model_rng = random.split(rng, 4)
    train_images, train_labels = create_synthetic_dataset(train_rng, num_samples=1000)
    val_images, val_labels = create_synthetic_dataset(val_rng, num_samples=200)
    model = CustomCNN(num_classes=10)
    lr_schedule = build_lr_schedule(base_lr=1e-3, warmup_steps=50, decay_steps=500)
    state = initialize_train_state(model_rng, model, (1, 32, 32, 3), lr_schedule)
    history = {'train_loss': [], 'train_accuracy': [], 'val_accuracy': []}
    print("Training started...")
    for epoch in range(num_epochs):
        epoch_metrics = []
        for batch in batch_data(train_images, train_labels, batch_size):
            # Fresh dropout key per step so masks differ across batches.
            rng, dropout_rng = random.split(rng)
            state, metrics = training_step(state, batch, dropout_rng)
            epoch_metrics.append(metrics)
        train_loss = jnp.mean(jnp.array([m['loss'] for m in epoch_metrics]))
        train_acc = jnp.mean(jnp.array([m['accuracy'] for m in epoch_metrics]))
        val_metrics = [evaluation_step(state, batch)
                       for batch in batch_data(val_images, val_labels, batch_size)]
        val_acc = jnp.mean(jnp.array([m['accuracy'] for m in val_metrics]))
        history['train_loss'].append(float(train_loss))
        history['train_accuracy'].append(float(train_acc))
        history['val_accuracy'].append(float(val_acc))
        print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    return history, state
# Run the full experiment and capture per-epoch metrics.
history, final_state = run_training(num_epochs=5)
# Plot loss and accuracy curves side by side for a quick training diagnosis.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))
ax_loss.plot(history['train_loss'], label='Training Loss')
ax_loss.set_title('Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)
# Train vs. validation accuracy on a shared axis highlights over/underfitting.
ax_acc.plot(history['train_accuracy'], label='Training Accuracy')
ax_acc.plot(history['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()
ax_acc.grid(True)
plt.tight_layout()
plt.show()
Summary and Key Takeaways
In this tutorial, we developed a comprehensive deep learning pipeline using JAX, Flax, and Optax that balances flexibility with high performance. We explored how to build custom modules such as residual convolutional blocks and multi-head self-attention layers, implement advanced optimization techniques including AdamW with gradient clipping, and manage training state effectively with batch normalization statistics. The use of JAX’s JIT compilation and vectorization capabilities significantly accelerates training and evaluation. This framework lays a solid foundation for scaling to real-world datasets and more complex architectures, empowering you to conduct efficient and reproducible machine learning experiments.
