Standard machine learning pulls all data to a central server, then trains a model. Federated learning flips that: the data stays where it is (your phone, hospital, bank), and the model travels. Each client trains locally, sends only model updates to a server, and the server averages them into a new global model. Useful when data is private, regulated, or simply too large to move.
The FedAvg algorithm (Google, 2016)
- Server initialises a global model w_0.
- Each round t: server samples K clients and broadcasts w_t.
- Each chosen client runs E local epochs of SGD on its own data, producing w_t^k.
- Clients send w_t^k (or just the delta Δw) back to the server.
- Server aggregates: w_{t+1} = Σ_k (n_k / n_total) · w_t^k, weighting each client by its share of the round's training samples.
- Repeat for hundreds of rounds.
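The loop above can be sketched as a small single-process simulation. This is a toy under stated assumptions, not a production implementation: the model is a two-parameter linear regressor, the "clients" are in-memory lists of synthetic (x, y) pairs, and the names local_sgd, fedavg_round and make_client are hypothetical helpers introduced only for illustration.

```python
import random


def local_sgd(weights, data, epochs, lr):
    """Run `epochs` passes of per-sample SGD on a linear model y = w*x + b."""
    w, b = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y   # prediction error for this sample
            w -= lr * err * x       # gradient of 0.5 * err^2 with respect to w
            b -= lr * err           # gradient of 0.5 * err^2 with respect to b
    return [w, b]


def fedavg_round(global_weights, clients, k, epochs, lr, rng):
    """One round: sample k clients, train locally, average weighted by data size."""
    chosen = rng.sample(clients, k)
    n_total = sum(len(data) for data in chosen)
    new_weights = [0.0] * len(global_weights)
    for data in chosen:
        # Each client starts from a copy of the broadcast global model.
        local = local_sgd(list(global_weights), data, epochs, lr)
        for i, value in enumerate(local):
            new_weights[i] += (len(data) / n_total) * value
    return new_weights


def make_client(rng, n_samples):
    """Synthetic client data (assumption): y = 2x + 1 plus small Gaussian noise."""
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n_samples)]
    return [(x, 2.0 * x + 1.0 + rng.gauss(0.0, 0.01)) for x in xs]


rng = random.Random(0)
clients = [make_client(rng, rng.randint(10, 50)) for _ in range(20)]

weights = [0.0, 0.0]        # w_0
for _ in range(30):         # a toy problem needs far fewer rounds than a real one
    weights = fedavg_round(weights, clients, k=5, epochs=2, lr=0.1, rng=rng)

print(weights)              # should land near the true parameters [2.0, 1.0]
```

Because every synthetic client here draws from the same distribution, this sketch shows only the mechanics of the round structure. It does not reproduce the difficulties that arise when client data differ.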
FedAvg aggregation:
w_{t+1} = Σ_{k=1..K} (n_k / n_total) · w_t^k
n_k = number of samples on client k
n_total = total samples across all participating clients
Larger clients get proportionally more weight in the average.
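As a quick numeric check of the weighting, here is a minimal sketch of the aggregation step on its own. The function name fedavg_aggregate and the three made-up clients are assumptions for illustration. Each model is a flat list of floats.

```python
def fedavg_aggregate(client_weights, client_sizes):
    """Weighted average of client models: sum over k of (n_k / n_total) * w^k."""
    n_total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum((n_k / n_total) * w[i] for w, n_k in zip(client_weights, client_sizes))
        for i in range(dim)
    ]


# Three hypothetical clients holding 100, 300 and 600 samples.
client_weights = [[1.0, 0.0], [3.0, 4.0], [5.0, 8.0]]
client_sizes = [100, 300, 600]

aggregated = fedavg_aggregate(client_weights, client_sizes)
print(aggregated)
```

With shares of 0.1, 0.3 and 0.6, the result is close to [4.0, 6.0] (up to floating-point rounding). A plain unweighted mean of the same three models would give [3.0, 4.0], so the largest client visibly pulls the aggregate toward its own parameters.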
Why this is hard
- Non-IID data: client A has mostly cats, client B has mostly dogs. Local models drift toward different optima, so the averaged model can converge slowly or stall.
- Communication cost: model weights can be huge (100 MB+); sending every round is expensive over mobile networks.
- Stragglers: if 90% of clients return in 30 seconds and 10% take 5 minutes, the server must choose between waiting for the slow ones and dropping them.
- System heterogeneity: phones have wildly different CPUs, batteries, network reliability.
- Privacy leakage: model updates are not raw data, but gradients can still leak training examples through inversion attacks. One mitigation is to clip each update and add differential-privacy noise, at some cost in accuracy.
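To make the last point concrete, here is a rough sketch of the clip-then-add-noise pattern that differentially private averaging is usually built on. It assumes updates are flat lists of floats and uses an unweighted average. The names clip_update and private_average, and the clip_norm and noise_multiplier parameters, are hypothetical. The sketch does not compute a privacy budget, so on its own it gives no formal guarantee.

```python
import math
import random


def clip_update(delta, clip_norm):
    """Scale delta down so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(v * v for v in delta))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in delta]


def private_average(deltas, clip_norm, noise_multiplier, rng):
    """Average the clipped updates, then add Gaussian noise to each coordinate."""
    clipped = [clip_update(d, clip_norm) for d in deltas]
    k = len(clipped)
    dim = len(clipped[0])
    # Clipping bounds each client's influence on the sum by clip_norm,
    # so the noise on the average is scaled by clip_norm / k.
    sigma = noise_multiplier * clip_norm / k
    return [
        sum(d[i] for d in clipped) / k + rng.gauss(0.0, sigma)
        for i in range(dim)
    ]


rng = random.Random(0)
deltas = [[3.0, 4.0], [0.3, 0.4], [-0.6, 0.8]]   # made-up client updates
noisy = private_average(deltas, clip_norm=1.0, noise_multiplier=1.0, rng=rng)
print(noisy)
```

Clipping is what makes the noise meaningful: without a bound on each update's norm, a single client could dominate the average no matter how much noise is added.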