import ComparisonTable from '../../components/ComparisonTable.astro';
PyTorch and JAX are both differentiable computing frameworks used for deep learning. PyTorch is the dominant framework for research and production. JAX is Google's framework, used by Google DeepMind for models like Gemini; it is functional, composable, and XLA-compiled.
Quick Verdict
Choose PyTorch if: you are building mainstream ML models, deploying to production, using Hugging Face, or working in a team where interoperability matters.
Choose JAX if: you are researching novel architectures, training on TPUs, prefer a functional programming style, or work in a Google/DeepMind research context.
Framework Philosophy
<ComparisonTable
  headers={["Dimension", "PyTorch", "JAX"]}
  rows={[
    ["Programming model", "Eager execution (pythonic)", "Functional + JIT (jax.jit)"],
    ["Auto-diff", "torch.autograd", "jax.grad, jax.jacobian"],
    ["Hardware", "GPU, CPU, MPS (Apple)", "GPU, CPU, TPU (native)"],
    ["JIT compiler", "torch.compile (TorchDynamo)", "XLA (whole functions via jax.jit)"],
    ["Distributed", "torch.distributed, FSDP", "pmap, shard_map"],
    ["Ecosystem", "Huge (Hugging Face, Lightning)", "Growing (Flax, Optax, Equinox)"],
    ["Production serving", "TorchServe, ONNX", "Less mature"],
    ["Debugging", "Standard Python debugging", "Harder (JIT traces)"],
    ["State management", "Mutable (nn.Module)", "Functional (no implicit state)"],
    ["Community", "Very large", "Large (academia + Google)"],
  ]}
/>
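The "Harder (JIT traces)" entry is easiest to see in a small experiment. The sketch below counts how often the Python body of a jitted function actually runs; `scaled_sum` and `trace_count` are illustrative names, not part of any library. Python side effects execute only while JAX traces the function, not on later calls that reuse the compiled version, which is why print statements and breakpoints inside jitted code can behave unexpectedly.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, updated by a Python side effect


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1           # runs only while JAX traces the function
    return jnp.sum(x * 2.0)


a = scaled_sum(jnp.ones(4))    # first call: traces and compiles
b = scaled_sum(jnp.ones(4))    # same shape and dtype: reuses compiled code
traces_same_shape = trace_count

c = scaled_sum(jnp.ones(8))    # new input shape: traced and compiled again
traces_new_shape = trace_count
```

In eager PyTorch the equivalent counter would increase on every call, since the Python body runs each time.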
Neural Network Definition
PyTorch — OOP with nn.Module:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        # Self-attention with residual
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_output)
        # Feed-forward with residual
        x = self.norm2(x + self.feed_forward(x))
        return x
class GPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, n_heads: int = 8,
                 n_layers: int = 6, max_seq_len: int = 1024):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device)
        # Causal mask: True blocks attention to later positions
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=tokens.device), diagonal=1
        )
        x = self.embedding(tokens) + self.pos_embedding(positions)
        for layer in self.layers:
            x = layer(x, causal_mask)
        return self.lm_head(self.norm(x))
# Training loop
model = GPT(vocab_size=50257).cuda()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

for batch in dataloader:
    tokens = batch['input_ids'].cuda()
    targets = batch['labels'].cuda()
    logits = model(tokens)
    loss = F.cross_entropy(logits.view(-1, 50257), targets.view(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
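The `mask` argument of `TransformerBlock.forward` is what makes a decoder autoregressive. As a minimal sketch, assuming a boolean mask built with `torch.triu` (the sizes and variable names are illustrative), the check below confirms that with a causal mask, changing the last token leaves the outputs at earlier positions untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model = 5, 16  # illustrative sequence length and width

# Boolean mask: True marks positions a query is NOT allowed to attend to.
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

attention = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
attention.eval()  # keep the comparison deterministic

x = torch.randn(1, T, d_model)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(d_model)  # alter only the last position

with torch.no_grad():
    out, _ = attention(x, x, x, attn_mask=causal_mask)
    out_changed, _ = attention(x_changed, x_changed, x_changed, attn_mask=causal_mask)

# Earlier positions cannot see the last token, so their outputs match.
prefix_unchanged = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5)
# The last position sees its own (changed) input, so its output differs.
last_changed = not torch.allclose(out[:, -1], out_changed[:, -1], atol=1e-5)
```

Without `attn_mask`, every position would attend to the changed token and the prefix outputs would differ as well.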
JAX — functional with Flax:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
class TransformerBlock(nn.Module):
    d_model: int
    n_heads: int
    d_ff: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic: bool = True):
        # Self-attention with residual
        attention_output = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
        )(x, x, mask=mask)
        attention_output = nn.Dropout(rate=self.dropout_rate)(
            attention_output, deterministic=deterministic
        )
        x = nn.LayerNorm()(x + attention_output)
        # Feed-forward with residual
        ff = nn.Dense(self.d_ff)(x)
        ff = nn.gelu(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=deterministic)
        ff = nn.Dense(self.d_model)(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=deterministic)
        x = nn.LayerNorm()(x + ff)
        return x


class GPT(nn.Module):
    vocab_size: int
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 6
    max_seq_len: int = 1024

    @nn.compact
    def __call__(self, tokens, deterministic: bool = True):
        B, T = tokens.shape
        positions = jnp.arange(T)
        # Causal mask: each position attends only to itself and earlier tokens
        mask = nn.make_causal_mask(tokens)
        x = nn.Embed(self.vocab_size, self.d_model)(tokens)
        x += nn.Embed(self.max_seq_len, self.d_model)(positions)
        for _ in range(self.n_layers):
            x = TransformerBlock(self.d_model, self.n_heads, self.d_model * 4)(
                x, mask=mask, deterministic=deterministic
            )
        x = nn.LayerNorm()(x)
        return nn.Dense(self.vocab_size)(x)
# JAX: Parameters are explicit (no implicit state)
model = GPT(vocab_size=50257)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 10), dtype=jnp.int32))
# Optimizer (via Optax)
tx = optax.adamw(learning_rate=3e-4, weight_decay=0.1)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)
# Functional training step, JIT-compiled
@jax.jit
def train_step(state, batch, dropout_key):
    def loss_fn(params):
        logits = model.apply(
            params, batch['input_ids'],
            deterministic=False,
            rngs={'dropout': dropout_key}
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape(-1, 50257),
            batch['labels'].reshape(-1)
        ).mean()
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
Key difference: JAX has no implicit state. Parameters, optimizer state, and random keys are passed to functions and returned from them explicitly, in a functional programming style.
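The same idea can be shown without Flax. The following is a minimal sketch in plain JAX, assuming a toy linear model and a hand-written SGD step (`predict`, `sgd_step`, and `LEARNING_RATE` are illustrative names): the update returns a new parameter tree instead of mutating the old one, and random keys are split and passed explicitly.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative step size


def predict(params, x):
    # A pure function: everything it needs arrives as an argument.
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Return NEW parameters instead of mutating the old ones.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)        # randomness is explicit state too
params = {"w": jax.random.normal(key_w, (3,)), "b": jnp.zeros(())}
x = jax.random.normal(key_x, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3   # synthetic regression targets

loss_before = loss_fn(params, x, y)
new_params = params
for _ in range(50):
    new_params = sgd_step(new_params, x, y)
loss_after = loss_fn(new_params, x, y)
```

After the loop, `params` still holds the initial values; only `new_params` reflects training. In PyTorch, `optimizer.step()` would have modified the module's tensors in place.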
Automatic Differentiation
PyTorch — imperative autograd:
# Simple gradient computation
x = torch.tensor([3.0], requires_grad=True)
y = x ** 2 + 2 * x + 1
y.backward()
print(x.grad) # dy/dx = 2x + 2 = 8.0
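# Higher-order gradients
from torch.autograd.functional import jacobian, hessian

# Placeholder inputs so the snippet runs on its own
params = torch.randn(3, requires_grad=True)
x_data = torch.randn(3)

def loss(params, x):
    # Nonlinear in params, so second derivatives are non-zero
    return ((params * x) ** 2).sum()

# Jacobian and Hessian of the loss with respect to params
J = jacobian(lambda p: loss(p, x_data), params)
H = hessian(lambda p: loss(p, x_data), params)
# Gradient of gradient
loss_value = loss(params, x_data)
first_order = torch.autograd.grad(loss_value, params, create_graph=True)
second_order = torch.autograd.grad(first_order[0].sum(), params)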
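PyTorch 2.x also ships function transforms in `torch.func` that mirror the JAX style shown in the next section. A brief sketch, reusing the polynomial from the first example (the names `f`, `df`, and `d2f` are illustrative):

```python
import torch
from torch.func import grad, vmap


def f(x):
    # Scalar function of a scalar tensor
    return x ** 2 + 2 * x + 1


df = grad(f)           # first derivative: 2x + 2
d2f = grad(grad(f))    # second derivative: 2

x = torch.tensor(3.0)
first = df(x)
second = d2f(x)

# vmap applies the derivative to a batch of inputs without a Python loop
xs = torch.tensor([0.0, 1.0, 2.0])
batched = vmap(df)(xs)
```

Unlike `backward()`, these transforms return gradients as values rather than accumulating them into `.grad`.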
JAX — composable transforms:
import jax
# jax.grad: automatic differentiation
f = lambda x: x ** 2 + 2 * x + 1
df = jax.grad(f)
print(df(3.0)) # 8.0
# Higher-order: compose grad with itself
d2f = jax.grad(jax.grad(f))
print(d2f(3.0)) # 2.0 (second derivative)
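# Jacobian and Hessian (vector_fn, scalar_fn, and x are placeholders)
jac = jax.jacobian(vector_fn)(x)
hess = jax.hessian(scalar_fn)(x)
# grad and value simultaneously
val, grads = jax.value_and_grad(loss_fn)(params)
# Gradient through a gradient step (meta-learning use case)
def outer_objective(slow_weights):
    # The inner update must happen inside the differentiated function;
    # 0.1 is an assumed inner-loop learning rate
    fast_weights = slow_weights - 0.1 * jax.grad(inner_loss)(slow_weights)
    return outer_loss(fast_weights)

outer_grad = jax.grad(outer_objective)(slow_weights)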
JAX’s composable transforms (jax.grad, jax.jit, jax.vmap, jax.pmap) are more elegant for research code requiring meta-learning, second-order optimization, or custom gradient operations.
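To make the meta-learning case concrete, here is a minimal sketch that differentiates through one inner gradient step and checks the result against a hand-derived answer. The quadratic losses and `INNER_LR` are assumptions chosen so the analytic gradient is easy to verify; they are not taken from any particular algorithm.

```python
import jax
import jax.numpy as jnp

INNER_LR = 0.1  # assumed inner-loop step size, chosen for illustration


def inner_loss(w):
    return 0.5 * jnp.sum(w ** 2)


def outer_loss(w):
    return jnp.sum(w ** 2)


def outer_objective(slow_weights):
    # One inner gradient step, taken inside the differentiated function
    fast_weights = slow_weights - INNER_LR * jax.grad(inner_loss)(slow_weights)
    return outer_loss(fast_weights)


slow_weights = jnp.array([2.0, -1.0])
# grad and jit compose: differentiate through the inner step, then compile
meta_grad = jax.jit(jax.grad(outer_objective))(slow_weights)

# Analytic check: fast = (1 - lr) * w, so d/dw sum(fast^2) = 2 * (1 - lr)^2 * w
expected = 2 * (1 - INNER_LR) ** 2 * slow_weights
```

The outer `jax.grad` sees the inner `jax.grad` as ordinary code, so no special second-order API is needed.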
Parallelism and Distribution
PyTorch:
# Distributed data parallel (one process per GPU, single- or multi-node)
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[local_rank])

# Fully Sharded Data Parallel (large models; an alternative to DDP,
# applied to the unwrapped model)
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
)

# torch.compile for optimization
model = torch.compile(model)
JAX:
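from functools import partial

# pmap: parallel map across devices
@partial(jax.pmap, axis_name='devices')
def train_step(state, batch):
    # Runs on each device with its own shard of the batch.
    # Gradients are NOT combined automatically: average them inside the
    # step with jax.lax.pmean(grads, axis_name='devices').
    ...

# Replicate state across all devices
state = jax.device_put_replicated(state, jax.devices())

# vmap: vectorize over batch dimension
# (used for per-example gradient computation, e.g., DP-SGD)
# in_axes=(None, 0): share params, map over the leading axis of batch
per_example_grads = jax.vmap(jax.grad(single_example_loss), in_axes=(None, 0))(params, batch)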
JAX’s pmap and TPU support are native: Google’s models are trained on TPUs with JAX. PyTorch reaches TPUs through PyTorch/XLA, which works but is less optimized.
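As a small runnable sketch of the pmap pattern, the example below computes a gradient on each device's shard and averages it explicitly with `jax.lax.pmean`. The toy loss and the names `averaged_grad` and `"devices"` are illustrative assumptions. It also runs on a CPU-only machine, where `jax.local_device_count()` is typically 1.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # typically 1 on a CPU-only machine


def loss_fn(w, x):
    return jnp.mean((w * x) ** 2)


@partial(jax.pmap, axis_name="devices")
def averaged_grad(w, x):
    local_grad = jax.grad(loss_fn)(w, x)          # gradient on this device's shard
    return jax.lax.pmean(local_grad, "devices")   # explicit cross-device average


# Leading axis holds one slice per device: replicated weight, sharded data
w = jnp.ones((n_devices,))
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

grads = averaged_grad(w, x)

# Reference: the same average computed without pmap
reference = jnp.mean(jax.vmap(jax.grad(loss_fn))(w, x))
```

Every entry of `grads` holds the same averaged value, which is what keeps replicated parameters in sync after an update.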
Ecosystem Comparison
PyTorch ecosystem:
Training:
- PyTorch Lightning (training loops)
- Hugging Face Trainer (transformers)
- accelerate (multi-device/multi-node)
Models:
- Hugging Face Transformers (hundreds of thousands of models on the Hub)
- torchvision, torchaudio, torchtext
Production:
- TorchServe (model serving)
- ONNX export (deploy anywhere)
- TorchScript (serialization)
- torch.export (2.x production export)
Monitoring:
- Weights & Biases, MLflow (native integrations)
JAX ecosystem:
Neural networks:
- Flax (Google DeepMind's NN library)
- Equinox (functional PyTrees approach)
- Haiku (DeepMind's modular NN)
Optimization:
- Optax (gradient transformations and optimizers)
Scientific computing:
- jax.scipy, jax.numpy (NumPy-compatible)
- Diffrax (differential equations)
LLM training:
- MaxText (Google's LLM training)
- Levanter
When to Choose Each
Choose PyTorch:
- Production ML system
- Using Hugging Face models (Transformers, Diffusers)
- Team not specialized in functional programming
- Need mature deployment tooling
- Mainstream academic or industry research (most published code targets PyTorch)
Choose JAX:
- Research on novel gradient-based algorithms
- TPU training (native and optimal)
- Functional programming style preferred
- Meta-learning, implicit differentiation, custom derivatives
- Working in Google/DeepMind research context
- Need extreme performance on custom hardware
Bottom Line
PyTorch is the practical choice for almost everyone — its ecosystem dominance, Hugging Face integration, and production tooling make it the clear default. JAX is the research choice for teams working at the frontier of ML methods, especially those using TPUs or needing advanced gradient computation. Both are excellent tools; the gap in capabilities has narrowed significantly, but the ecosystem gap strongly favors PyTorch for most applications.