Deep Learning - Advanced - 12 min

Learn Batch Normalization

Last updated: 2026-05-13.

As a neural network trains, each layer's input distribution keeps shifting because the parameters of the layers before it are changing, so every layer is constantly chasing a moving target. This is called internal covariate shift. Batch Normalization (BatchNorm) was designed to reduce it: at every training step it normalises each feature of a layer's input to zero mean and unit variance using the statistics of the current mini-batch, then lets the network re-learn the optimal scale and shift.

The Four Steps of BatchNorm

For a mini-batch B = {x₁, x₂, ..., xₘ} of values of one feature (each feature, or each channel in a CNN, is normalised independently):

Step 1 — Batch mean:
  μ_B = (1/m) Σ xᵢ

Step 2 — Batch variance:
  σ²_B = (1/m) Σ (xᵢ − μ_B)²

Step 3 — Normalise:
  x̂ᵢ = (xᵢ − μ_B) / √(σ²_B + ε)     → mean 0, variance ≈ 1

Step 4 — Scale and shift:
  yᵢ = γ · x̂ᵢ + β                    → mean β, variance γ²

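γ and β are learned parameters (one pair per feature, trained like weights). They restore the
network's ability to represent whatever mean and scale it needs: setting γ = √(σ²_B + ε) and
β = μ_B recovers the original activations, so the layer can undo the normalisation if that helps.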

ε is a small constant (1e-5 is the PyTorch default) that prevents division by zero; γ and β are initialised to 1 and 0 respectively and then learned by backprop.
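The four steps can be sketched in NumPy as follows. This is a minimal training-mode forward pass, assuming the input is a 2-D array of shape (m, features) with one row per example; the function name batch_norm_train is hypothetical. It also checks the identity-recovery property of γ and β described above.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm forward pass (illustrative sketch).

    x: array of shape (m, features), one row per example in the mini-batch.
    gamma, beta: arrays of shape (features,), the learned scale and shift.
    """
    mu = x.mean(axis=0)                    # Step 1: per-feature batch mean
    var = x.var(axis=0)                    # Step 2: per-feature batch variance (divides by m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # Step 3: normalise
    y = gamma * x_hat + beta               # Step 4: scale and shift
    return y, mu, var

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # a mini-batch far from mean 0, variance 1

y, mu, var = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6))  # per-feature means ≈ 0
print(y.std(axis=0).round(3))   # per-feature standard deviations ≈ 1

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalisation.
y_id, _, _ = batch_norm_train(x, gamma=np.sqrt(var + 1e-5), beta=mu)
print(np.allclose(y_id, x))     # True
```

With γ = 2 and β = 3 the same call would give per-feature mean ≈ 3 and standard deviation ≈ 2, matching "mean β, variance γ²".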

Training vs Inference

During training:
  μ and σ² are computed fresh from each mini-batch.
  The network sees slightly different normalisation every step → mild regularisation.

During inference:
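  Batch statistics are unavailable or unreliable (predictions may be made one example at a time),
  and a prediction should not depend on which other samples happen to share its batch.
  Instead, exponential moving averages accumulated during training are used (momentum close to 1, e.g. 0.9):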
    μ_running ← momentum · μ_running + (1−momentum) · μ_batch
    σ²_running ← momentum · σ²_running + (1−momentum) · σ²_batch
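A minimal NumPy sketch of the train/inference switch, using exactly the update rule above. The class SimpleBatchNorm and its training flag are hypothetical illustrations, not a library API, and momentum = 0.9 is an assumed value.

```python
import numpy as np

class SimpleBatchNorm:
    """Minimal BatchNorm with running statistics (illustrative sketch).

    Update rule: running = momentum * running + (1 - momentum) * batch,
    so momentum is close to 1 (0.9 here, an assumed value).
    """

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mu, var = x.mean(axis=0), x.var(axis=0)   # fresh mini-batch statistics
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mu
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            mu, var = self.running_mean, self.running_var  # frozen statistics
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

rng = np.random.default_rng(0)
bn = SimpleBatchNorm(num_features=3)
for _ in range(200):                       # "training": batches with mean 10, std 2
    bn(rng.normal(10.0, 2.0, size=(32, 3)))

bn.training = False                        # inference mode
single = np.array([[10.0, 12.0, 8.0]])     # one example, no mini-batch needed
print(bn.running_mean.round(1), bn.running_var.round(1))
print(bn(single).round(2))                 # roughly [0, 1, -1] standard deviations
```

In inference mode the output for one example depends only on that example and the stored statistics, which is the property the running averages exist to provide.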

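In PyTorch, model.eval() switches BatchNorm layers to the running statistics and model.train() switches them back.
Forgetting model.eval() before validation is a common bug: outputs then depend on batch composition and the running statistics keep updating, so validation loss looks wrong.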

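PyTorch's nn.BatchNorm1d/2d/3d layers store running_mean and running_var as buffers and update them automatically in training mode. Note that PyTorch's momentum argument uses the opposite convention to the formula above: running ← (1 − momentum) · running + momentum · batch, with a default of 0.1.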
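The train/eval behaviour can be observed directly on a single nn.BatchNorm1d layer. This short sketch assumes PyTorch's default settings (momentum = 0.1, track_running_stats=True); the batch shape and values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)   # running_mean starts at 0, running_var at 1

x = torch.randn(16, 3) * 2.0 + 5.0    # a training batch with mean ≈ 5, std ≈ 2

bn.train()                            # training mode: batch statistics, buffers updated
_ = bn(x)
print(bn.running_mean)                # moved 10% of the way toward the batch mean

bn.eval()                             # inference mode: running statistics, buffers frozen
before = bn.running_mean.clone()
out = bn(x[:1])                       # a single example works in eval mode
print(torch.equal(before, bn.running_mean))  # True: eval() does not update the buffers

bn.train()
try:
    bn(x[:1])                         # a batch of one cannot be normalised in training mode
except ValueError as err:
    print("train mode, batch of 1:", err)
```

PyTorch updates running_var with the unbiased (m − 1 denominator) batch variance, even though the normalisation itself uses the biased one.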

Why BatchNorm works so well

  • Allows higher learning rates: because activations are normalised, gradients stay better scaled across layers, so you can often train with a noticeably larger learning rate without diverging.
  • Reduces sensitivity to initialisation: bad weight initialisation causes layer distributions to drift; BatchNorm resets them every step.
  • Acts as regularisation: each example is normalised with the statistics of whichever mini-batch it lands in, which injects slight noise (different μ, σ per batch) and can reduce the need for dropout.
  • Smoother loss landscape: gradients become less spiky, so optimisation converges faster and more reliably.
  • Mitigates vanishing/exploding gradients: activations stay in a reasonable range even in very deep networks.
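The last bullet can be illustrated with a forward pass through a deep ReLU network whose weights are deliberately initialised too small. This NumPy sketch only tracks forward activations (not gradients), and the depth, width, batch size and weight scale are assumed values chosen to make the effect visible.

```python
import numpy as np

def normalise(h, eps=1e-5):
    # Steps 1-3 of BatchNorm (gamma = 1, beta = 0) with per-feature batch statistics.
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def final_preactivation_std(use_bn, depth=20, width=256, batch=512, seed=0):
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, width))
    for _ in range(depth):
        # Deliberately poor initialisation: weights too small for ReLU layers.
        W = rng.normal(scale=0.5 / np.sqrt(width), size=(width, width))
        z = h @ W                    # linear layer (no bias)
        if use_bn:
            z = normalise(z)         # Linear -> BatchNorm -> ReLU
        h = np.maximum(z, 0.0)       # ReLU
    return z.std()

plain = final_preactivation_std(use_bn=False)
with_bn = final_preactivation_std(use_bn=True)
print(f"without BatchNorm: {plain:.2e}")    # shrinks by about sqrt(8) per layer
print(f"with BatchNorm:    {with_bn:.3f}")  # stays ≈ 1
```

Without normalisation each layer multiplies the activation variance by roughly 0.25 × 0.5 = 0.125, so after 20 layers the signal has all but vanished; with BatchNorm every layer starts again from unit variance.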

Where to place BatchNorm

Common placement — after linear/conv, before activation:
  x → Linear → BatchNorm → ReLU → next layer

Alternative (pre-norm, used in Transformers, with LayerNorm instead of BatchNorm):
  x → LayerNorm → Attention/FF → + x (residual), i.e. out = x + Attention(LayerNorm(x))

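Do NOT apply BatchNorm to the output layer: the network's predictions (logits, regression values) should not be re-normalised per batch.
For very small batch sizes (as a rule of thumb, fewer than 8), consider LayerNorm or GroupNorm instead:
BatchNorm statistics become unreliable with very few samples.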
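The difference between BatchNorm and LayerNorm is which axis the statistics are taken over. This NumPy sketch omits the learned γ and β for brevity (GroupNorm is not shown) and uses hypothetical function names.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # BatchNorm: statistics over the batch axis (one mean/variance per feature).
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

def layer_norm(x, eps=1e-5):
    # LayerNorm: statistics over the feature axis (one mean/variance per example).
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 6))
one = batch[:1]                       # the same first example, as a batch of size 1

print(batch_norm(one))                # all zeros: each value minus its own "batch mean"
print(np.allclose(layer_norm(one), layer_norm(batch)[:1]))  # True: independent of the batch
print(np.allclose(batch_norm(one), batch_norm(batch)[:1]))  # False: depends on the batch
```

With a batch of one, BatchNorm's statistics collapse completely, while LayerNorm gives the same result regardless of batch size.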

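ResNet and EfficientNet use BatchNorm after each convolution (the original VGG predates BatchNorm, though BN variants such as torchvision's vgg16_bn exist); Transformers use LayerNorm.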
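The "after linear/conv, before activation" placement looks like this in PyTorch. The layer sizes are arbitrary assumptions; dropping the bias before BatchNorm is a common convention, since BatchNorm subtracts the batch mean and β supplies the shift.

```python
import torch
import torch.nn as nn

# MLP block: Linear -> BatchNorm1d -> ReLU, with a plain Linear output layer.
mlp = nn.Sequential(
    nn.Linear(20, 64, bias=False),   # bias is redundant before BatchNorm
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),               # output layer: no BatchNorm after it
)

# Conv block in the ResNet style: Conv2d -> BatchNorm2d -> ReLU.
conv = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),              # one gamma/beta pair per channel
    nn.ReLU(),
)

x_vec = torch.randn(8, 20)
x_img = torch.randn(8, 3, 32, 32)
print(mlp(x_vec).shape)              # torch.Size([8, 10])
print(conv(x_img).shape)             # torch.Size([8, 16, 32, 32])
```

BatchNorm2d normalises each channel over the batch and both spatial dimensions, so its γ and β have one entry per channel.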

Practice questions

  1. What problem does Batch Normalization primarily solve?
  2. In the BatchNorm formula yᵢ = γ·x̂ᵢ + β, what are γ and β?
  3. During inference (test time), how does BatchNorm compute mean and variance?
  4. Why does BatchNorm allow training with a much larger learning rate?
