As a neural network trains, each layer's input distribution keeps shifting because the layers before it are changing, so every layer is chasing a moving target. This effect is known as internal covariate shift. Batch Normalization (BatchNorm) was designed to counter it by normalising each layer's inputs to zero mean and unit variance over every training mini-batch, then letting the network re-learn whatever scale and shift works best.
The Four Steps of BatchNorm
For a mini-batch B = {x₁, x₂, ..., xₘ}:
Step 1 — Batch mean:
μ_B = (1/m) Σ xᵢ
Step 2 — Batch variance:
σ²_B = (1/m) Σ (xᵢ − μ_B)²
Step 3 — Normalise:
x̂ᵢ = (xᵢ − μ_B) / √(σ²_B + ε) → zero mean, unit variance (not necessarily Gaussian)
Step 4 — Scale and shift:
yᵢ = γ · x̂ᵢ + β → mean β, variance γ²
γ and β are learned parameters (like weights). They restore the network's ability to represent whatever scale and offset it needs, including undoing the normalisation entirely.
ε ≈ 1e-5 prevents division by zero; γ and β start at 1 and 0 and are then learned by backprop.
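The four steps can be sketched in a few lines of plain Python for a single feature. This is a minimal illustration, not a library implementation: the function name `batch_norm_1d`, the example batch, and the chosen γ = 2, β = 1 are all assumptions made for the demonstration.

```python
import math

def batch_norm_1d(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Apply the four BatchNorm steps to one feature across a mini-batch."""
    m = len(xs)
    mu = sum(xs) / m                                        # Step 1: batch mean
    var = sum((x - mu) ** 2 for x in xs) / m                # Step 2: batch variance (divides by m)
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]   # Step 3: normalise
    y = [gamma * xh + beta for xh in x_hat]                 # Step 4: scale and shift
    return y, mu, var

batch = [2.0, 4.0, 6.0, 8.0]            # hypothetical activations of one feature
y, mu, var = batch_norm_1d(batch, gamma=2.0, beta=1.0)

# Check the statistics of the output against what Step 4 predicts.
y_mean = sum(y) / len(y)
y_var = sum((v - y_mean) ** 2 for v in y) / len(y)
print(mu, var)          # batch statistics
print(y_mean, y_var)    # close to beta and gamma**2
```

The output mean matches β exactly, while the output variance is only approximately γ² because ε slightly shrinks the normalised values.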
Training vs Inference
During training:
μ and σ² are computed fresh from each mini-batch.
The network sees slightly different normalisation every step → mild regularisation.
During inference:
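A full mini-batch may not be available (often there is only a single sample), and a prediction should not depend on which other samples happen to share the batch.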
Instead, running statistics accumulated during training are used:
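μ_running ← momentum · μ_running + (1−momentum) · μ_batch
σ²_running ← momentum · σ²_running + (1−momentum) · σ²_batch
Here momentum is the weight on the old value (e.g. 0.9 or 0.99). PyTorch's momentum argument is the weight on the new batch value instead (default 0.1), so the two roles are swapped there.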
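The running-statistics update is an exponential moving average, which can be sketched as follows. The helper name `update_running`, the starting values, and the per-batch statistics are hypothetical; `momentum` follows the convention in the formulas above (weight on the old value), which is the complement of the `momentum` argument PyTorch uses.

```python
def update_running(running, batch_value, momentum=0.9):
    """One exponential-moving-average step; momentum weights the OLD value."""
    return momentum * running + (1 - momentum) * batch_value

running_mean, running_var = 0.0, 1.0                   # assumed starting values
batch_stats = [(5.0, 4.0), (5.2, 3.8), (4.8, 4.2)]     # hypothetical (mean, variance) per batch

for mu_b, var_b in batch_stats:
    running_mean = update_running(running_mean, mu_b)
    running_var = update_running(running_var, var_b)

print(running_mean, running_var)
```

After only three updates the running mean is still far from the batch means of about 5, which is why the running statistics are only trustworthy after many training steps.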
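In PyTorch, model.eval() switches to running statistics automatically, and model.train() switches back to batch statistics.
Forgetting to call model.eval() is a common bug: validation loss looks wrong because it is still computed with per-batch statistics.
PyTorch tracks running_mean and running_var automatically in its nn.BatchNorm1d/2d/3d layers.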
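The difference between the two modes can be seen on a single layer. This is a small sketch assuming PyTorch is installed; the layer size, seed, and input distribution are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)              # 3 features; PyTorch's momentum defaults to 0.1
x = torch.randn(16, 3) * 4 + 10     # hypothetical batch with mean near 10, std near 4

bn.train()
y_train = bn(x)                     # normalised with this batch's own statistics
# As a side effect, running_mean and running_var moved one step towards the batch values.

bn.eval()
with torch.no_grad():
    y_eval = bn(x)                  # normalised with the running statistics instead

print(y_train.mean(dim=0))          # close to 0 for every feature
print(y_eval.mean(dim=0))           # far from 0: running stats have barely moved yet
print(bn.running_mean, bn.running_var)
```

The same input produces different outputs in the two modes, which is why evaluating a model that is still in training mode gives misleading numbers.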
Why BatchNorm works so well
- Allows higher learning rates: with normalised activations, gradients flow more uniformly across layers, so training often tolerates a learning rate several times larger (roughly 5–10× in many setups) without diverging.
- Reduces sensitivity to initialisation: bad weight initialisation causes layer distributions to drift; BatchNorm resets them every step.
- Acts as regularisation: the mini-batch statistics inject slight noise (different μ, σ per batch), which can reduce the need for dropout.
- Smoother loss landscape: gradients become less spiky, so optimisation tends to converge faster and more reliably.
- Mitigates vanishing/exploding gradients: activations stay in a reasonable range even in very deep networks.
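One mechanism behind the points on initialisation and learning rate is that the normalisation step is insensitive to the scale and offset of its input. The sketch below illustrates this in plain Python; the helper `normalise`, the sample values, and the factor 50 and offset 7 are assumptions chosen for the demonstration.

```python
import math

def normalise(xs, eps=1e-5):
    """Steps 1 to 3 of BatchNorm for one feature across a mini-batch."""
    m = len(xs)
    mu = sum(xs) / m
    var = sum((x - mu) ** 2 for x in xs) / m
    return [(x - mu) / math.sqrt(var + eps) for x in xs]

pre_activations = [0.3, -1.2, 0.8, 2.1]                  # hypothetical layer outputs
badly_scaled = [50.0 * x + 7.0 for x in pre_activations]  # same layer with inflated weights and an offset

a = normalise(pre_activations)
b = normalise(badly_scaled)
max_gap = max(abs(p - q) for p, q in zip(a, b))
print(max_gap)   # tiny; the only difference comes from eps
```

Whatever scale the preceding weights produce, the next layer sees nearly the same normalised values.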
Where to place BatchNorm
Common placement — after linear/conv, before activation:
x → Linear → BatchNorm → ReLU → next layer
Alternative (pre-norm, used in Transformers):
x → LayerNorm → Attention/FF → + x (residual)
Do NOT apply BatchNorm to the output layer.
For small batch sizes (< 8), consider LayerNorm or GroupNorm instead: BatchNorm statistics become unreliable with very few samples.
ResNets, EfficientNet and the BatchNorm variants of VGG all use BatchNorm after conv; Transformers use LayerNorm.
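The common placement might look like this in PyTorch. This is a sketch under assumed layer sizes (20 inputs, 64 hidden units, 10 outputs); dropping the bias of a layer that feeds BatchNorm is a frequent design choice rather than a requirement, since β already provides an offset.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(20, 64, bias=False),   # bias omitted: BatchNorm's beta provides the offset
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 64, bias=False),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),               # output layer: no BatchNorm
)

x = torch.randn(32, 20)              # hypothetical batch of 32 samples
logits = model(x)
print(logits.shape)

# For very small batches, a batch-independent layer such as LayerNorm
# normalises each sample over its own features instead.
layer_norm = nn.LayerNorm(64)
h = layer_norm(torch.randn(2, 64))
print(h.shape)
```

Because LayerNorm computes its statistics per sample, its output for one sample does not change with the batch size.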