Why This Matters
A 7B-parameter transformer in bfloat16 needs about 14 GB for weights alone, but Adam training commonly stores weights, gradients, two moment vectors, and often a master fp32 copy. A rough training footprint is 16 bytes per parameter, so 7B parameters need about 112 GB before activations. One 80 GB accelerator is already short.
Even when the model fits, a single accelerator may take weeks. Distributed training trades extra communication for more compute devices. The central question is numeric: does the saved compute time exceed the time spent moving gradients, activations, parameters, and optimizer state?
Core Definitions
Data parallelism
Data parallelism replicates the model on each worker, splits a global batch into per-worker mini-batches, computes gradients locally, then averages gradients with an all-reduce before the optimizer step. If there are W workers and per-worker batch size b, the global batch size is B = W · b.
Tensor parallelism
Tensor parallelism shards a single tensor operation across workers. A large matrix multiplication, attention projection, or MLP projection is split along rows, columns, or heads. Workers exchange partial outputs with all-reduce, all-gather, or reduce-scatter.
Pipeline parallelism
Pipeline parallelism assigns contiguous layer ranges to different workers. Micro-batches flow through the stages so that stage 0 can process micro-batch 2 while stage 1 processes micro-batch 1.
Collective operation
A collective operation is a communication primitive involving a group of workers. The common collectives are all-reduce, all-gather, and reduce-scatter. Training systems usually call vendor libraries such as NCCL rather than writing collectives directly.
ZeRO and fully sharded data parallel
ZeRO partitions training state across data-parallel workers. Stage 1 shards optimizer state, stage 2 also shards gradients, and stage 3 also shards parameters. Fully sharded data parallel follows the same idea: gather parameters for a layer when needed, compute, then release or reshard them.
Data Parallel Training
In synchronous data parallel training, every worker starts a step with identical parameters. Worker i computes a local gradient g_i on its local mini-batch. The gradient used by the optimizer is the arithmetic mean: g = (g_1 + g_2 + ... + g_W) / W.
For four workers, suppose a scalar parameter has local gradients 0.20, 0.10, -0.05, and 0.15. The all-reduce sum is 0.40, and the averaged gradient is 0.10. With learning rate 0.001, SGD changes the parameter by -0.0001.
A minimal C++ sketch hides many details but shows the ordering invariant. No worker may update parameters before the averaged gradient is available.
// One synchronous data-parallel step.
// all_reduce_sum writes the elementwise sum into grad.
for (int step = 0; step < steps; ++step) {
zero_grad(model);
forward_backward(local_batch, model, grad);
all_reduce_sum(grad.data(), grad.numel(), group);
for (size_t k = 0; k < grad.numel(); ++k) {
grad[k] /= group.size();
}
adam_update(model.params(), grad, optimizer_state);
}
The communication volume is large. A model with 1B parameters has a bfloat16 gradient of 2 GB. Ring all-reduce over W workers moves approximately 2(W - 1)/W · S GB per worker for a tensor of size S GB, so eight workers move about 3.5 GB each. On a 900 GB/s intra-node fabric, the bandwidth-only lower bound is about 3.9 ms. Across a 400 Gb/s link, about 50 GB/s before protocol overhead, the same traffic has a lower bound near 70 ms.
Gradient bucketing reduces exposed latency. Rather than waiting until backprop finishes, frameworks place gradients into buckets, for example 25 MB each. When a bucket is complete, its all-reduce starts while earlier layers continue backprop. The step time is then closer to the maximum of compute and communication than their sum.
Tensor Parallel Matrix Multiplication
Transformer layers contain large matrix multiplications. For Y = XW, let X have shape [1, H] and W have shape [H, 4H] for the MLP expansion. With H = 4 and output width 4H = 16, a small example is:
Let column parallelism split W into two column blocks W0 and W1, each of shape [4, 8] if the full output width is 16. Worker 0 computes Y0 = X W0, and worker 1 computes Y1 = X W1. The full Y is the concatenation [Y0, Y1], so the next operation may require an all-gather.
For row parallelism, split X and W along the hidden dimension: X into column blocks X0 and X1 of shape [1, 2], and W into row blocks W0 and W1 of shape [2, 16].
Each worker computes a partial output X_i W_i of shape [1, 16]. The workers must all-reduce the partial outputs. This is why Megatron-style tensor parallelism alternates split dimensions to keep some operations local and place collectives at known points.
The byte layout matters. Eight bfloat16 values occupy 16 bytes. If a row-sharded activation has four bfloat16 elements per worker, two workers hold these bytes:
worker 0 X0, shape [1,4], bfloat16:
value: 1.0 2.0 3.0 4.0
bytes: 80 3f 00 40 40 40 80 40
worker 1 X1, shape [1,4], bfloat16:
value: 5.0 6.0 7.0 8.0
bytes: a0 40 c0 40 e0 40 00 41
The split is a logical tensor partition, not a compression. Total bytes remain the same, but no single device stores the entire activation slice.
Pipeline Parallel Training
Pipeline parallelism divides the layer sequence. For a 24-layer transformer on four devices, a simple split assigns layers 0-5, 6-11, 12-17, and 18-23. If a mini-batch is split into m micro-batches, the forward pass forms a time grid.
time: 0 1 2 3 4 5
stage 0: F0 F1 F2 F3 F4 F5
stage 1: . F0 F1 F2 F3 F4
stage 2: . . F0 F1 F2 F3
stage 3: . . . F0 F1 F2
The empty cells are pipeline bubbles. For p stages and m micro-batches, a simple forward-only utilization approximation is m / (m + p - 1). With p = 4 and m = 6, that is 6/9, about 0.67. Larger m reduces bubbles but stores more activations unless activation recomputation is used.
Backward scheduling adds weight-version constraints. GPipe uses synchronous flush scheduling, which keeps weight versions simple. PipeDream-style schedules can keep devices busier but must track which weight version produced each activation. If the wrong version is used in backward, the gradient no longer corresponds to the forward pass that created the loss.
Pipeline communication sends activations forward and activation gradients backward. If a boundary activation is in bfloat16, its size is b × s × h × 2 bytes for micro-batch size b, sequence length s, and hidden size h. With micro-batch size 2, sequence length 2048, and hidden size 4096, one boundary activation is 33,554,432 bytes, or 32 MiB. Each boundary sends that in forward and again in backward.
ZeRO, FSDP, and Sharded State
The Adam state footprint dominates training. Per parameter, a common mixed-precision accounting is 2 bytes for bf16 parameter, 2 bytes for bf16 gradient, 4 bytes for fp32 master parameter, 4 bytes for first moment, and 4 bytes for second moment. That is 16 bytes per parameter.
For a 10B-parameter model, the unsharded state is 160 GB. Across eight data-parallel workers:
replicated DP per worker:
params bf16 20 GB
grads bf16 20 GB
master fp32 40 GB
Adam m fp32 40 GB
Adam v fp32 40 GB
total 160 GB
ZeRO-1 per worker:
params bf16 20 GB
grads bf16 20 GB
optimizer shard 15 GB
total 55 GB
ZeRO-2 per worker:
params bf16 20 GB
grads shard 2.5 GB
optimizer shard 15 GB
total 37.5 GB
ZeRO-3 or FSDP per worker:
params shard 2.5 GB
grads shard 2.5 GB
optimizer shard 15 GB
total 20 GB
Stage 3 pays in communication. Before computing a layer, workers all-gather the parameter shard for that layer. After backward, gradients are reduce-scattered so each worker keeps only its shard. FSDP implementations often wrap module blocks so that parameter all-gather and reshard happen at module boundaries.
This changes the memory timeline. A worker may briefly hold a full layer’s parameters, but not the full model. The right wrap size balances communication frequency against peak memory.
Collective Operations and Interconnects
All-reduce, all-gather, and reduce-scatter are the vocabulary of distributed training. If each worker starts with one tensor chunk, all-gather leaves every worker with all chunks. Reduce-scatter first reduces elementwise, then leaves each worker with one reduced shard. All-reduce is equivalent to reduce-scatter followed by all-gather.
A ring all-reduce splits the tensor into W chunks. Each worker repeatedly sends one chunk to its neighbor and receives one chunk from the other neighbor. It has high bandwidth use for large tensors and cost proportional to 2(W - 1)/W · S bytes per worker for a tensor of S bytes. A tree all-reduce uses a reduction tree and a broadcast tree. It has lower latency terms, roughly log2(W) message steps, which makes it better for small tensors.
Interconnect placement changes the parallelism choice. Within an 8-GPU node, NVLink on systems such as H100 SXM has about 900 GB/s aggregate GPU-to-GPU bandwidth per GPU. Across nodes, 200 to 400 Gb/s InfiniBand or RoCE gives about 25 to 50 GB/s before overhead. A tensor-parallel group that communicates every transformer block is usually kept inside a node. Data parallelism or ZeRO groups are more likely to cross nodes because they can communicate fewer, larger buckets.
Key Result
For one training step, a useful planning model is: T_step ≈ max(T_compute, T_comm) with full overlap, degrading toward T_compute + T_comm with none.
With overlap, communication that fits under backprop compute is hidden. For a gradient bucket of B bytes, ring all-reduce on W workers has the bandwidth term: T ≥ 2(W - 1)/W · B / β.
Here β is the effective link bandwidth. For W = 8, B = 25 MiB, and β = 900 GB/s, the lower bound is about 0.051 ms. With β = 50 GB/s, it is about 0.92 ms. Latency, PCIe hops, kernel launch overhead, and contention add to this bound.
The invariant is stricter than the formula. Synchronous DP requires identical parameters at the start of each step. Tensor parallelism requires matching shards for a single logical tensor operation. Pipeline parallelism requires backward to use the same weight version as forward, unless the algorithm explicitly tolerates stale weights. ZeRO-3 and FSDP require a parameter to be gathered before its computation and resharded only after all consumers finish.
Common Confusions
All-reduce is not the same as all-gather
All-gather concatenates shards. It does not add values. If worker 0 has gradient [1, 2] and worker 1 has [3, 4], all-gather gives both workers [1, 2, 3, 4]. All-reduce sum gives both workers [4, 6].
Tensor parallelism does not reduce total FLOPs
Sharding a matrix multiplication divides FLOPs across devices, but the global multiplication is the same size. The win is lower wall time per layer when communication is smaller than the saved local compute.
ZeRO-3 saves memory but adds parameter traffic
A ZeRO-3 worker does not hold all parameters all the time. It must all-gather full parameters for the active module. If modules are wrapped too finely, many small all-gathers expose latency.
A lost worker is different from a slow worker
A straggler delays the step but eventually contributes gradients. A lost worker breaks the collective. In synchronous training, the remaining workers cannot finish an NCCL all-reduce that includes the missing rank.
Exercises
Problem
Four data-parallel workers train a model with two scalar parameters. Their local gradients are g1, g2, g3, and g4, each a two-component vector of your choosing. Compute the averaged gradient and the SGD update for learning rate 0.01.
Problem
A tensor has 64 MiB of bf16 gradients. Estimate the per-worker byte traffic for ring all-reduce on eight workers. Then estimate the bandwidth-only time on 900 GB/s and 50 GB/s links.
Problem
A 12B-parameter model uses the 16-byte-per-parameter Adam accounting from this page. Compute per-worker memory for replicated DP and for ZeRO-3 across eight workers. Ignore activations and allocator overhead.
References
Canonical:
- Vaswani et al., Attention Is All You Need (2017), §§3.1-3.3, transformer layer shapes and attention projections
- NVIDIA, CUDA C++ Programming Guide (2024), §§5.3, 8.2, memory hierarchy, streams, and asynchronous execution
- Williams, Waterman, and Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures (2009), §§1-3, compute and bandwidth ceilings
- Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (SC 2020), §§3-4, ZeRO stages and partitioned optimizer state
- Huang et al., GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS 2019), §§2-3, micro-batch pipeline scheduling
- Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (SC 2021), §§3-4, tensor and pipeline parallel transformer training
Accessible:
- PyTorch documentation, FullyShardedDataParallel, API notes and sharding strategy descriptions
- NVIDIA NCCL documentation, Collective Operations, definitions of all-reduce, all-gather, and reduce-scatter
- Hugging Face documentation, Model Parallelism, overview of data, tensor, pipeline, and ZeRO-style sharding
Next Topics
- /computationpath/roofline-model
- /computationpath/cuda-execution-model
- /computationpath/gpu-memory-hierarchy
- /computationpath/inference-serving-basics
- /topics/scaling-laws-for-language-models