Training Techniques
Mixed Precision Training
Train with FP16 or BF16 for speed while keeping FP32 master weights for accuracy. Loss scaling, overflow prevention, and when mixed precision fails.
Why This Matters
Training a large model in FP32 is significantly slower than FP16/BF16 on modern GPUs: 4-8x slower for matmul on H100/A100 tensor cores, and roughly 2x slower on non-tensor-core ops or older hardware. It also uses 2x more memory for weights and activations. Mixed precision training gives you most of this speedup while maintaining FP32 accuracy. Every large language model trained since 2018 uses some form of mixed precision. Understanding the underlying floating-point arithmetic is essential for diagnosing training failures.
Mental Model
Keep a "master copy" of weights in FP32. For each training step: cast weights to FP16/BF16, run the forward and backward pass in reduced precision (fast), then update the FP32 master weights with the computed gradients. The reduced-precision passes are fast because GPU tensor cores execute FP16/BF16 matmuls at a multiple of full-precision throughput (at least 2x, and considerably more compared to non-tensor-core FP32).
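This loop can be sketched with numpy (a toy illustration of the bookkeeping, not a framework API; the model and data here are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
master_w = rng.normal(size=4).astype(np.float32)   # FP32 master weights
x = rng.normal(size=(8, 4)).astype(np.float32)
y = (x @ master_w + 0.1).astype(np.float32)        # synthetic regression targets

lr = 0.1
for _ in range(3):
    w16 = master_w.astype(np.float16)              # cast weights for the fast pass
    x16, y16 = x.astype(np.float16), y.astype(np.float16)
    pred = x16 @ w16                               # forward in FP16
    grad16 = 2.0 * (x16.T @ (pred - y16)) / len(x16)      # backward in FP16
    master_w = master_w - lr * grad16.astype(np.float32)  # update master in FP32
```

The only FP32 arithmetic is the final weight update; everything on the hot path runs in reduced precision.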
Formal Setup
Mixed Precision Training
A training procedure that maintains model weights in FP32 (the master weights) while computing forward activations and gradients in FP16 or BF16. Weight updates are applied in FP32:
$$\theta_{t+1} = \theta_t - \eta\, \tilde{g}_t$$
where $\tilde{g}_t$ is the gradient, computed in FP16/BF16 and cast back to FP32 before the update, and $\eta$ is the learning rate.
FP16 vs BF16
FP16 (IEEE half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The representable range is approximately $[6 \times 10^{-8},\ 65504]$ (the lower figure, $2^{-24} \approx 5.96 \times 10^{-8}$, is the minimum positive subnormal; the minimum positive normal is $2^{-14} \approx 6.10 \times 10^{-5}$).
BF16 (brain floating point) uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. The representable range is approximately $[1.18 \times 10^{-38},\ 3.39 \times 10^{38}]$ for normal values.
The critical difference: BF16 has the same exponent range as FP32 but lower precision. FP16 has higher precision but a much smaller exponent range. For training, the exponent range matters more than mantissa precision. Gradients that are very small (below $2^{-24} \approx 6 \times 10^{-8}$) underflow to zero in FP16 but are representable in BF16. Activations that are large (above 65504) overflow in FP16 but not in BF16.
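The FP16 limits are easy to verify with numpy's `float16` (a quick sanity check of the figures above; plain numpy has no BF16 type, so only the FP16 side is shown):

```python
import numpy as np

print(np.float16(1e-8))     # 0.0: below the smallest subnormal (~5.96e-8)
print(np.float16(70000.0))  # inf: above the largest finite value (65504)
print(np.finfo(np.float16).max)                 # 65504.0
print(np.finfo(np.float16).smallest_subnormal)  # ~5.96e-08
```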
This is why BF16 has largely replaced FP16 for training on hardware that supports it (A100 and later GPUs, TPUs).
Loss Scaling
Loss Scaling
Loss scaling multiplies the loss by a constant $S$ before the backward pass, then divides gradients by $S$ after:
$$g = \frac{1}{S}\, \nabla_\theta (S \cdot L)$$
By linearity of differentiation, $\nabla_\theta (S \cdot L) = S\, \nabla_\theta L$, so $g = \nabla_\theta L$ in exact arithmetic. The purpose is to shift the gradient distribution into the representable range of FP16 during the backward pass.
In practice, dynamic loss scaling starts with a large $S$ (e.g., $2^{16} = 65536$) and halves $S$ whenever an overflow (NaN/Inf) is detected. If no overflow occurs for $N$ consecutive steps (e.g., $N = 2000$), $S$ is doubled. This adapts $S$ to the gradient magnitude throughout training.
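A minimal dynamic loss scaler might look like the sketch below (the class and its constants are illustrative, not a specific library's API; PyTorch's `GradScaler` plays this role in practice):

```python
import numpy as np

class DynamicLossScaler:
    """Halve the scale on overflow; double it after enough clean steps."""

    def __init__(self, init_scale=2.0**16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def unscale(self, scaled_grads):
        """Return unscaled gradients, or None if the step must be skipped."""
        if not np.all(np.isfinite(scaled_grads)):
            self.scale /= 2.0        # overflow detected: back off
            self._good_steps = 0
            return None
        unscaled = scaled_grads / self.scale
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= 2.0        # stable for a while: probe a larger scale
            self._good_steps = 0
        return unscaled

scaler = DynamicLossScaler()
print(scaler.unscale(np.array([np.inf])))  # None: this step is skipped
print(scaler.scale)                        # 32768.0: scale was halved
```

A step that returns `None` is simply dropped; the optimizer never sees the overflowed gradients.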
Main Theorems
Loss Scaling Preserves Gradient Direction
Statement
Let $L$ be a differentiable loss function and $S > 0$ a finite scaling factor. If no floating-point overflow occurs during computation of $\nabla_\theta (S \cdot L)$, then:
$$\frac{1}{S}\, \nabla_\theta (S \cdot L) = \nabla_\theta L$$
in exact arithmetic. In FP16 arithmetic, the scaled computation preserves gradient components that would otherwise underflow to zero, at the cost of potentially overflowing large components.
Intuition
Loss scaling shifts the entire gradient histogram to the right on a log scale. Components that were below the FP16 minimum ($\approx 6 \times 10^{-8}$) are lifted into representable range. Components near the FP16 maximum ($65504$) may overflow. Dynamic scaling finds the sweet spot automatically.
Proof Sketch
By the chain rule, $\nabla_\theta (S \cdot L) = S\, \nabla_\theta L$. Dividing by $S$ recovers $\nabla_\theta L$ exactly. In floating-point arithmetic, the multiplication by $S = 2^k$ shifts the exponent of each gradient component up by $k$ positions, preventing underflow for components whose exponent was within $k$ of the minimum.
Why It Matters
Without loss scaling, FP16 training of deep networks fails because a large fraction of gradient components (often 50%+) fall below the FP16 minimum and become zero. Loss scaling is what makes FP16 training possible in practice. This is particularly important when combined with SGD convergence guarantees that assume nonzero gradient information.
Failure Mode
Loss scaling cannot help when gradients span a range wider than FP16 can represent (about 12 orders of magnitude from the smallest subnormal to the largest finite value). If the largest gradient component overflows even at scale $S = 1$, or the smallest underflows even at the largest scale that avoids overflow, mixed precision with FP16 breaks. BF16 avoids this by having a much wider exponent range (about 76 orders of magnitude, matching FP32).
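The failure is mechanical and can be checked directly: with two synthetic gradient components 14 orders of magnitude apart, no power-of-two scale keeps both finite and nonzero in FP16 (illustrative values, assuming numpy):

```python
import numpy as np

grads = np.array([1e5, 1e-9])        # span: ~14 orders of magnitude

def survives(scale):
    scaled = (grads * scale).astype(np.float16)
    return bool(np.all(np.isfinite(scaled)) and np.all(scaled != 0))

# Keeping 1e5 finite needs scale <= ~0.65; lifting 1e-9 needs scale >= ~60.
any_scale_works = any(survives(2.0**k) for k in range(-20, 40))
print(any_scale_works)   # False
```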
When Mixed Precision Fails
Gradient accumulation errors. When accumulating gradients across microbatches in FP16, the running sum can lose precision. Small gradient contributions get rounded away when added to a large accumulator. The fix: accumulate in FP32.
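The effect is easy to reproduce: an FP16 running sum of many small contributions stalls once the accumulator's spacing exceeds the contribution size (a toy demonstration, assuming numpy):

```python
import numpy as np

contrib = np.float16(0.01)
acc16, acc32 = np.float16(0.0), np.float32(0.0)
for _ in range(10_000):
    acc16 = acc16 + contrib               # FP16 accumulator
    acc32 = acc32 + np.float32(contrib)   # FP32 accumulator

# Near 32.0 the FP16 spacing is 0.03125, so adding 0.01 rounds to no change.
print(float(acc16))   # stalls around 32
print(float(acc32))   # ~100, as expected
```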
Attention logit overflow. In transformers, the attention logits can exceed 65504 in FP16, causing NaN. This happens with long sequences or poorly scaled attention. The fix: compute attention in FP32, or use BF16 which handles the range.
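The overflow happens at the cast itself, before any softmax trick can help (synthetic logit values, assuming numpy):

```python
import numpy as np

logits = np.array([1.0e5, 0.0], dtype=np.float32)  # one logit beyond 65504

def softmax(x):
    z = x - x.max()          # max-subtraction keeps exp() in range
    e = np.exp(z)
    return e / e.sum()

naive16 = np.exp(logits.astype(np.float16))  # cast gives [inf, 1.0]
print(naive16 / naive16.sum())               # [nan, 0.]: inf / inf
print(softmax(logits))                       # [1., 0.]: FP32 path is fine
```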
Small weight updates. When the learning rate is very small and the gradient is moderate, the update can underflow in FP16. The master weight strategy already handles this by updating in FP32, but naive implementations that skip master weights will fail.
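A two-line check of why the master copy matters (illustrative numbers; just below 1.0 the FP16 spacing is $2^{-11} \approx 0.0005$, so an update much smaller than that rounds away):

```python
import numpy as np

update = np.float16(1e-4)                    # e.g. learning_rate * gradient

w16 = np.float16(1.0) - update               # pure FP16 weight
w32 = np.float32(1.0) - np.float32(update)   # FP32 master weight

print(float(w16))   # 1.0: the update vanished
print(float(w32))   # ~0.9999: the update was applied
```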
Common Confusions
Mixed precision does not mean training in FP16
Mixed precision means using both FP16/BF16 and FP32 at different stages. Pure FP16 training (without FP32 master weights) diverges for most models. The "mixed" part is the key.
BF16 usually does not need loss scaling
Because BF16 has the same 8-bit exponent range as FP32, the gradient underflow problem that motivates FP16 loss scaling is essentially gone, and most BF16 training pipelines run without a scaling step. This is a practical advantage but not an unconditional guarantee: BF16's mantissa is much shorter than FP32 (7 bits vs 23), so there are settings — very small learning rates, gradient accumulation, second-order optimizers, FP8 / BF16 mixed precision — where round-off and accumulation issues can still cause silent precision loss. The right takeaway is "FP16 essentially requires loss scaling; BF16 typically does not", not "BF16 has no numerical failure modes".
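Plain numpy has no BF16 dtype, so the sketch below emulates BF16 round-to-nearest by keeping only the top 16 bits of an FP32 value (an approximation that ignores the ties-to-even subtlety). It shows the short-mantissa failure mode: near 1.0, BF16 spacing is $2^{-7} \approx 0.0078$, so a small weight update disappears:

```python
import numpy as np

def to_bf16(x):
    """Round an FP32 value to the nearest BF16-representable value."""
    bits = np.float32(x).view(np.uint32)
    bits = np.uint32((int(bits) + 0x8000) & 0xFFFF0000)  # keep top 16 bits
    return float(bits.view(np.float32))

print(to_bf16(1.0 + 1e-3))   # 1.0: the update is below half the spacing
print(to_bf16(1.0 + 1e-2))   # 1.0078125: large enough to register
```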
Memory savings are not 2x
The master weights are still FP32. The memory savings come from FP16 activations (which dominate memory for large models) and FP16 gradients. For a model with $N$ parameters, you need $4N$ bytes for master weights plus $2N$ bytes for FP16 weights, versus $4N$ bytes for FP32 only. Activation memory savings are model-dependent.
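The arithmetic for the weight tensors alone, using a hypothetical 7B-parameter model (activations, gradients, and optimizer state excluded):

```python
N = 7_000_000_000                 # hypothetical parameter count

fp32_only = 4 * N                 # FP32 weights: 4 bytes each
mixed = 4 * N + 2 * N             # FP32 master copy + FP16 working copy

print(fp32_only / 2**30)   # ~26.1 GiB
print(mixed / 2**30)       # ~39.1 GiB: weights alone cost more, not less
```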
Summary
- Keep FP32 master weights, compute forward and backward in FP16/BF16
- Loss scaling prevents gradient underflow in FP16 by shifting the gradient distribution
- BF16 is preferred over FP16 for training because its wider exponent range eliminates most overflow and underflow issues
- Dynamic loss scaling adapts automatically; start high and halve on overflow
- Accumulate gradients in FP32 to avoid precision loss
Exercises
Problem
A gradient component has value $10^{-8}$ in FP32. Will this underflow to zero in FP16? If you apply loss scaling with $S = 2^{16}$, what is the scaled value, and does it survive in FP16?
Problem
In a transformer with hidden dimension $d$ and sequence length $n$, the attention logits are $QK^\top / \sqrt{d}$. Under the i.i.d. standard-normal baseline for the entries of $Q$ and $K$, estimate the typical maximum logit. Then explain why FP16 overflow still occurs in practice despite this baseline being far below 65504.
References
Canonical:
- Micikevicius et al., "Mixed Precision Training" (2018), ICLR
- Kalamkar et al., "A Study of BFLOAT16 for Deep Learning Training" (2019)
Current:
- NVIDIA, "Training with Mixed Precision" documentation (2023)
- Dehghani et al., "The Efficiency Misnomer" (2022), Section 4
Last reviewed: April 26, 2026