Our Pick PyTorch — PyTorch's dominant ecosystem, mature production tooling, and intuitive API make it the right default for most ML work. JAX's functional programming model and XLA compilation excel in research settings and custom hardware (TPUs), and powers Google DeepMind's most recent models.
PyTorch vs JAX

import ComparisonTable from ’../../components/ComparisonTable.astro’;

PyTorch and JAX are both differentiable computing frameworks used for deep learning. PyTorch is the dominant research and production framework. JAX is Google DeepMind’s framework that powers models like Gemini — functional, composable, and XLA-compiled.

Quick Verdict

Choose PyTorch if: Building most ML models, deploying to production, using Hugging Face, or working in a team where interoperability matters.

Choose JAX if: Doing research on novel architectures, using TPUs, working in functional programming style, or working in a Google/DeepMind research context.


Framework Philosophy

<ComparisonTable headers={[“Dimension”, “PyTorch”, “JAX”]} rows={[ [“Programming model”, “Eager execution (pythonic)”, “Functional + JIT (jax.jit)”], [“Auto-diff”, “torch.autograd”, “jax.grad, jax.jacobian”], [“Hardware”, “GPU, CPU, MPS (Apple)”, “GPU, CPU, TPU (native)”], [“JIT compiler”, “torch.compile (TorchDynamo)”, “XLA (always-on via jax.jit)”], [“Distributed”, “torch.distributed, FSDP”, “pmap, shard_map”], [“Ecosystem”, “Huge (Hugging Face, Lightning)”, “Growing (Flax, Optax, Equinox)”], [“Production serving”, “TorchServe, ONNX”, “Less mature”], [“Debugging”, “Standard Python debugging”, “Harder (JIT traces)”], [“State management”, “Mutable (nn.Module)”, “Functional (no implicit state)”], [“Community”, “Very large”, “Large (academia + Google)”], ]} />


Neural Network Definition

PyTorch — OOP with nn.Module:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        # Self-attention with residual
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_output)
        # Feed-forward with residual
        x = self.norm2(x + self.feed_forward(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, n_heads: int = 8,
                 n_layers: int = 6, max_seq_len: int = 1024):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device)
        x = self.embedding(tokens) + self.pos_embedding(positions)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.norm(x))


# Training loop
model = GPT(vocab_size=50257).cuda()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

for batch in dataloader:
    tokens = batch['input_ids'].cuda()
    targets = batch['labels'].cuda()
    
    logits = model(tokens)
    loss = F.cross_entropy(logits.view(-1, 50257), targets.view(-1))
    
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

JAX — functional with Flax:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax


class TransformerBlock(nn.Module):
    d_model: int
    n_heads: int
    d_ff: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # Self-attention with residual
        attention_output = nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads,
        )(x, x)
        attention_output = nn.Dropout(rate=self.dropout_rate)(
            attention_output, deterministic=deterministic
        )
        x = nn.LayerNorm()(x + attention_output)

        # Feed-forward with residual
        ff = nn.Dense(self.d_ff)(x)
        ff = nn.gelu(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=deterministic)
        ff = nn.Dense(self.d_model)(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=deterministic)
        x = nn.LayerNorm()(x + ff)
        return x


class GPT(nn.Module):
    vocab_size: int
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 6
    max_seq_len: int = 1024

    @nn.compact
    def __call__(self, tokens, deterministic: bool = True):
        B, T = tokens.shape
        positions = jnp.arange(T)
        x = nn.Embed(self.vocab_size, self.d_model)(tokens)
        x += nn.Embed(self.max_seq_len, self.d_model)(positions)

        for _ in range(self.n_layers):
            x = TransformerBlock(self.d_model, self.n_heads, self.d_model * 4)(
                x, deterministic=deterministic
            )

        x = nn.LayerNorm()(x)
        return nn.Dense(self.vocab_size)(x)


# JAX: Parameters are explicit (no implicit state)
model = GPT(vocab_size=50257)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 10), dtype=jnp.int32))

# Optimizer (via Optax)
tx = optax.adamw(learning_rate=3e-4, weight_decay=0.1)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)


# Functional training step — JIT-compiled
@jax.jit
def train_step(state, batch, dropout_key):
    def loss_fn(params):
        logits = model.apply(
            params, batch['input_ids'],
            deterministic=False,
            rngs={'dropout': dropout_key}
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape(-1, 50257),
            batch['labels'].reshape(-1)
        ).mean()
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

Key difference: JAX has no implicit state. Parameters are passed explicitly everywhere — functional programming style.


Automatic Differentiation

PyTorch — imperative autograd:

# Simple gradient computation
x = torch.tensor([3.0], requires_grad=True)
y = x ** 2 + 2 * x + 1

y.backward()
print(x.grad)  # dy/dx = 2x + 2 = 8.0

# Higher-order gradients
from torch.autograd.functional import jacobian, hessian

def loss(params, x):
    return (params @ x).sum()

# Jacobian
J = jacobian(lambda p: loss(p, x_data), params)

# Gradient of gradient
grad_fn = torch.autograd.grad(loss_value, params, create_graph=True)
second_order = torch.autograd.grad(grad_fn[0].sum(), params)

JAX — composable transforms:

import jax

# jax.grad: automatic differentiation
f = lambda x: x ** 2 + 2 * x + 1
df = jax.grad(f)
print(df(3.0))  # 8.0

# Higher-order: compose grad with itself
d2f = jax.grad(jax.grad(f))
print(d2f(3.0))  # 2.0 (second derivative)

# Jacobian and Hessian
jac = jax.jacobian(vector_fn)(x)
hess = jax.hessian(scalar_fn)(x)

# grad and value simultaneously
val, grads = jax.value_and_grad(loss_fn)(params)

# Gradient of a gradient (meta-learning use case)
inner_grad = jax.grad(inner_loss)(fast_weights)
outer_grad = jax.grad(lambda w: outer_loss(w, inner_grad))(slow_weights)

JAX’s composable transforms (jax.grad, jax.jit, jax.vmap, jax.pmap) are more elegant for research code requiring meta-learning, second-order optimization, or custom gradient operations.


Parallelism and Distribution

PyTorch:

# Data parallel (single node, multi-GPU)
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])

# Fully Sharded Data Parallel (large models)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
)

# torch.compile for optimization
model = torch.compile(model)

JAX:

# pmap: parallel map across devices
@jax.pmap
def train_step(state, batch):
    # Runs on each device independently
    # Gradients automatically summed across devices
    ...

# Replicate state across all devices
state = jax.device_put_replicated(state, jax.devices())

# vmap: vectorize over batch dimension
# (used for per-example gradient computation, e.g., DP-SGD)
per_example_grads = jax.vmap(jax.grad(single_example_loss))(params, batch)

JAX’s pmap and TPU support are native — Google’s models are trained on TPUs with JAX. PyTorch’s TPU support is through PyTorch XLA (works but less optimized).


Ecosystem Comparison

PyTorch ecosystem:

Training:
- PyTorch Lightning (training loops)
- Hugging Face Trainer (transformers)
- accelerate (multi-device/multi-node)

Models:
- Hugging Face Transformers (10,000+ models)
- torchvision, torchaudio, torchtext

Production:
- TorchServe (model serving)
- ONNX export (deploy anywhere)
- TorchScript (serialization)
- torch.export (2.x production export)

Monitoring:
- Weights & Biases, MLflow (native integrations)

JAX ecosystem:

Neural networks:
- Flax (Google DeepMind's NN library)
- Equinox (functional PyTrees approach)
- Haiku (DeepMind's modular NN)

Optimization:
- Optax (gradient transformers and optimizers)

Scientific computing:
- jax.scipy, jax.numpy (NumPy-compatible)
- Diffrax (differential equations)

LLM training:
- MaxText (Google's LLM training)
- Levanter

When to Choose Each

Choose PyTorch:

  • Production ML system
  • Using Hugging Face models (Transformers, Diffusers)
  • Team not specialized in functional programming
  • Need mature deployment tooling
  • Most academic and industry research today

Choose JAX:

  • Research on novel gradient-based algorithms
  • TPU training (native and optimal)
  • Functional programming style preferred
  • Meta-learning, implicit differentiation, custom derivatives
  • Working in Google/DeepMind research context
  • Need extreme performance on custom hardware

Bottom Line

PyTorch is the practical choice for almost everyone — its ecosystem dominance, Hugging Face integration, and production tooling make it the clear default. JAX is the research choice for teams working at the frontier of ML methods, especially those using TPUs or needing advanced gradient computation. Both are excellent tools; the gap in capabilities has narrowed significantly, but the ecosystem gap strongly favors PyTorch for most applications.