
The Danger of Batch Normalization in Deep Learning

Reading time: 6 min · Published on: Jun 11, 2025

If you have ever read the implementation of a deep learning model, chances are you have already encountered BatchNorm (Batch Normalization). This is a very common operation that is used to accelerate the training of large models and to stabilise unstable ones.

However, if you are a practitioner, it’s quite possible that you have also struggled with this operation, which notoriously causes many problems. In this article, we review the problems we most often encounter and propose a few solutions.

Fix-it Checklist (2025 Edition): Batch-Norm Gotchas

  • Use Ghost Batch Norm when the effective batch is < 256: keeps statistics stable in small-batch or multi-GPU runs.
  • Switch to SyncBatchNorm across devices: shares running stats so validation ≈ training.
  • Try Group Norm or Layer Norm if memory is tight: batch-size-agnostic and just as fast on PyTorch 2.7.
  • Fuse norm + activation with EvoNorm for CNNs: fewer FLOPs, can boost top-1 by ~0.5 pp.
  • On video/streaming inference, prefer Group Norm: prevents future-frame data leakage.
  • Run one frozen-weights epoch to recompute stats: eliminates the “lag” between training and inference.
  • Compile with torch.compile(..., mode="reduce-overhead") and switch tensors to channels_last: yields a ~10–30 % throughput bump on modern GPUs.

What is a Batch Normalisation Layer?

BatchNorm aims to solve the problem of internal covariate shift. What this means is that for a given layer in a deep network, the output has a mean and standard deviation across the dataset. During training, this mean and standard deviation are unconstrained and can drift arbitrarily, which can cause numerical stability issues. The BatchNorm operation removes this problem by normalising the layer’s output. However, it is too costly to evaluate the mean and the standard deviation on the whole dataset, so we only evaluate them on a batch of data.

$$\hat{x}_n = \frac{x_n - m_i}{\sigma_i}, \qquad m_i = \frac{1}{N}\sum_{n=1}^{N} x_n, \qquad \sigma_i^2 = \frac{1}{N}\sum_{n=1}^{N}\left(x_n - m_i\right)^2$$

Normalised input, with the mean and standard deviation computed over the N elements of batch i.

This works well in practice, but we cannot do the same at inference time, because data often arrives one sample at a time, so batch averages no longer make sense. To solve this problem, modern implementations compute a running average over the batch statistics seen during training.

$$\tilde{m}_t = (1 - \varepsilon)\,\tilde{m}_{t-1} + \varepsilon\, m_t, \qquad \tilde{\sigma}_t^2 = (1 - \varepsilon)\,\tilde{\sigma}_{t-1}^2 + \varepsilon\, \sigma_t^2$$

Running mean and running standard deviation over successive batches, where (1 − ε) is the momentum (or persistency) of previous samples.
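To make the two regimes concrete, here is a minimal PyTorch sketch (shapes and values are illustrative) showing that the running estimates are only updated in training mode and only used in evaluation mode:

import torch
import torch.nn as nn

# In PyTorch, the update is: running = (1 - momentum) * running + momentum * batch_stat,
# so the `momentum` argument plays the role of epsilon in the formula above
# (not to be confused with the layer's numerical-stability `eps`). It defaults to 0.1.
bn = nn.BatchNorm2d(num_features=8)

x = torch.randn(16, 8, 4, 4)   # a batch of 16 samples with 8 channels

bn.train()                     # batch-estimation mode: normalise with batch stats, update running stats
_ = bn(x)
print(bn.running_mean[:3], bn.running_var[:3])

bn.eval()                      # inference mode: normalise with the running estimates instead
_ = bn(x)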

The Problem

In summary, the behaviour is different between training and inference: at training time the batch statistics $m_t$ and $\sigma_t$ are used, whereas at inference time the running averages $\tilde{m}_t$ and $\tilde{\sigma}_t$ are used. This difference is the root of all evil, as the metrics in validation and in training can be very different.

More precisely, as the true batch statistics evolve during training, the running averages often lag behind, which can cause a significant difference. In principle, if the batch is large and the model converges properly, those quantities should end up matching.

But in practice, this assumption often fails or is impractical to check. For example, it will not be obvious whether a large discrepancy between the training and validation loss is due to severe overfitting or simply because those running statistics have not converged yet.

More dangerously, we regularly observe that although the training loss converges to some value, the validation loss can remain considerably higher, due to the mean and standard deviation of the BatchNorm never stabilising.

We, the authors, are not entirely sure of the cause of the problem, but we believe this can happen when the minimum is highly degenerate. For example, in a loss landscape like the one illustrated below, the model keeps moving randomly around the circular valley, causing the running averages to lag behind forever.

“Mexican hat” loss landscape, where many different parameter configurations can lead to the same global minimum.

Why things break (recap + new pitfalls)

  1. Small-batch chaos  When the batch size is below 32, the estimate of µ/σ² is so noisy that validation accuracy can lag behind training by 5 pp or more. Ghost Batch Norm (GBN) fixes this by slicing the batch into virtual chunks and normalising each chunk.
  2. Multi-GPU drift  Each GPU used to keep its own statistics; torch.nn.SyncBatchNorm now broadcasts them so all devices stay in sync (see the conversion sketch after this list).
  3. Video time‑leak  For sequential data BN may “peek” at future frames because the batch mixes timesteps; medical‑video papers showed inflated AUCs until they swapped BN for GN.
  4. Spectral bias in coordinate networks  A 2024 CVPR study found that BN actually helps tiny MLPs learn high‑frequency details by widening the NTK eigen‑spectrum. So no, BN is not obsolete.
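For the multi-GPU case, PyTorch ships a converter that swaps every BatchNorm layer for its synchronised counterpart; a minimal sketch on a toy model (your real network goes in its place):

import torch.nn as nn

# A toy CNN with ordinary BatchNorm layers, standing in for your real model.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

# Replace every BatchNorm layer with SyncBatchNorm so that batch statistics are
# computed across all GPUs in the process group rather than per device.
# Note: training then requires torch.distributed to be initialised.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)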

Modern alternatives at a glance

Modern Alternatives to Batch Norm

  • Layer Norm (sweet spot: Transformers, RNNs, seq-to-seq): nn.LayerNorm(d_model) — batch-size-agnostic.
  • Group Norm (sweet spot: vision when the batch is tiny): nn.GroupNorm(32, C) — 32 groups works well.
  • EvoNorm-S0/B0 (sweet spot: CNNs where you’d fuse ReLU + BN): evonorm.S0(k, groups=16) — fewer FLOPs.
  • DyT (Dynamic Tanh) (sweet spot: “normalization-free” Transformers): experimental; worth keeping an eye on.

Pro tip: You can swap BN layers on a pre‑trained conv net for GN without retraining from scratch — often needs only a few epochs of fine‑tuning.
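One possible way to script that swap; a rough sketch, assuming a torchvision ResNet and that you fine-tune for a few epochs afterwards (the 32-group choice mirrors the table above):

import torch.nn as nn
from torchvision.models import resnet18

def bn_to_gn(module: nn.Module, groups: int = 32) -> None:
    """Recursively replace every BatchNorm2d layer with a GroupNorm over the same channels."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            g = groups if channels % groups == 0 else 1   # fall back if the channel count is not divisible
            setattr(module, name, nn.GroupNorm(g, channels))
        else:
            bn_to_gn(child, groups)

model = resnet18(weights="DEFAULT")   # any pre-trained conv net will do
bn_to_gn(model)
# ...then fine-tune for a few epochs so the remaining weights adapt to the new normalisation.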

The Solution

The first thing to do if you encounter this problem is to try a few standard tricks. Here are some typical ones:

  • Try another normalisation layer (e.g. LayerNorm, InstanceNorm…);
  • Increase the batch size, which can stabilise the estimation of the mean and standard deviation among the batch;
  • Play with the momentum parameter of the moving average: it controls how persistent previous batches are in the running average, in other words how far the estimates can “lag behind” (see the sketch after this list);
  • Shuffle your training set in every epoch, to avoid correlation between data points.
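To make the momentum point concrete, here is what those knobs look like in PyTorch (the values are illustrative, not recommendations):

import torch.nn as nn

# Default: each new batch contributes 10 % to the running estimates (momentum=0.1).
bn_default = nn.BatchNorm2d(64)

# Lower momentum: the running estimates move more slowly but are much smoother,
# which helps when individual batch statistics are noisy.
bn_smooth = nn.BatchNorm2d(64, momentum=0.01)

# momentum=None: PyTorch switches to a cumulative average over all batches seen so far.
bn_cumulative = nn.BatchNorm2d(64, momentum=None)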

However, sometimes those basic tricks will not suffice. In that case we propose a more powerful trick.

Keep in mind that we have two different behaviours of the BatchNorm layer:

  1. In what we will call Batch Estimation Mode, the mean and standard deviation are estimated on the batch. This is the mode used in training;
  2. In what we will call Inference Mode, the mean and standard deviation come from the running averages estimated on previous batches. This is what is usually used during validation and inference.

Our solution works in two steps.

First, we deactivate the difference between training and validation by always using the batch estimation mode.

Secondly, in order to use the model in production, we still need to estimate the mean and standard deviation to be able to use inference mode. So, after the model is trained, we calculate the mean and standard deviation to be used.

By doing so, they are evaluated on a model with fixed weights, and we avoid the “lagging behind” effect described above. More concretely, after training, we freeze all the weights of the model and run one epoch to estimate the running averages on the whole dataset.
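Here is a minimal sketch of that recalibration pass, assuming a `model` and a `train_loader` already exist; setting `momentum=None` makes each BatchNorm layer accumulate an exact average over the epoch instead of an exponential one:

import torch
import torch.nn as nn

@torch.no_grad()
def recalibrate_batchnorm(model: nn.Module, loader) -> None:
    """Freeze the weights and re-estimate the BatchNorm running stats on the whole dataset."""
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()   # wipe the (possibly lagging) old estimates
            m.momentum = None         # None => cumulative moving average over all batches

    model.train()                     # batch-estimation mode, but no_grad keeps the weights frozen
    for x, _ in loader:
        model(x)                      # forward passes only update the running statistics

    model.eval()                      # back to inference mode, now with stable statistics

# Usage after training: recalibrate_batchnorm(model, train_loader)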

The 2025 PyTorch code snippet

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self, groups=8):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(groups, 64),        # safer than BatchNorm on tiny batches
            nn.SiLU(),
            nn.Conv2d(64, 10, 1)
        )

    def forward(self, x):
        return self.conv(x).mean((-2, -1))

model = TinyNet().to(memory_format=torch.channels_last)  # channels_last for free speed
model = torch.compile(model, mode="reduce-overhead")     # torch.compile, available since PyTorch 2.0

If you do need BatchNorm statistics, swapping nn.GroupNorm back to nn.BatchNorm2d is a two-line change.

Experimenting with the Solution

In order to show the advantage of our solution, let’s run a small experiment. We purposely used a very bad architecture and trained it with a relatively high learning rate, resulting in a model with unstable BatchNorm. The code, written in Python with PyTorch, is available.

The network is a stack of 3 convolution layers with BatchNorm and ReLU activations, followed by a global average pooling layer (a sketch of this architecture follows the list of modes below). We trained it on MNIST for 10 epochs with the Adam optimiser. The figure below shows the training and validation accuracy per epoch in 4 modes:

  • Mode 0: No BatchNorm layers are used.
  • Mode 1: Basic BatchNorm with no modifications.
  • Mode 2: Almost Smart BatchNorm: we use the running stats at inference, but we do not run the extra frozen epoch to estimate them on the whole dataset.
  • Mode 3: Smart BatchNorm: we estimate the dataset statistics over one frozen epoch before switching to inference mode.
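For reference, here is a rough sketch of this kind of toy network (an illustration, not the authors’ exact code), assuming MNIST-shaped inputs of 1×28×28; the channel widths are illustrative:

import torch.nn as nn

def make_toy_net(use_batchnorm: bool = True) -> nn.Sequential:
    """Three conv blocks (Conv -> BatchNorm -> ReLU), then global average pooling to 10 logits."""
    def block(c_in, c_out):
        layers = [nn.Conv2d(c_in, c_out, 3, padding=1)]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
        return layers

    return nn.Sequential(
        *block(1, 32), *block(32, 64), *block(64, 10),
        nn.AdaptiveAvgPool2d(1),   # global average pooling
        nn.Flatten(),              # (N, 10, 1, 1) -> (N, 10)
    )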
Train and validation accuracy of a 3-layer toy network on MNIST (10 epochs). Mode 0: no BatchNorm; mode 1: basic BatchNorm; mode 3: smart BatchNorm.

We observe two things: first, BatchNorm helps increase the accuracy; secondly, without our solution, the validation metric is erratic and uninformative. Finally, we report the test accuracy for all 4 modes.

Test accuracy after 10 epochs for the 4 modes described above (0, 1, 2, 3).
Test accuracy after 10 epochs for modes 0, 1 and 3.

As you can see, our solution gives the best results. Mode 2 is particularly bad: we activate the running stats (inference mode) but never estimate those statistics on the dataset, so when testing in inference conditions with a batch size of 1, the results are poor. This shows the necessity of combining the running statistics at inference time with an estimation of the dataset statistics over a whole epoch before using the model for inference.

Is that Solution Perfect?

No, obviously not! Many bad things can still happen. The trickiest issue is that your estimated mean and standard deviation will still differ from the batch estimates, and some really odd phenomena can still hit you hard.

For example, it has been shown that some models can actually encode information in the statistical noise of the batch. Fortunately, such extreme cases are rare, and experience has shown that this solution is quite robust: it should only improve your performance and save you a lot of headaches.

If you want to avoid weird behaviors with your BatchNorm layers, go for it.



FAQ

Can I remove BatchNorm layers once the model is trained?

Yes, but you must first fold the layer’s affine parameters (γ, β) and the running mean/variance into the preceding convolution or linear layer. Otherwise predictions will shift and accuracy will drop.
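For a convolution with weights $W$ and bias $b$ followed by a BatchNorm layer with running mean $\tilde{m}$, running variance $\tilde{\sigma}^2$ and affine parameters $\gamma$, $\beta$, the standard folding identity (applied per output channel) is:

$$W' = \frac{\gamma}{\sqrt{\tilde{\sigma}^2 + \epsilon}}\,W, \qquad b' = \beta + \frac{\gamma\,\left(b - \tilde{m}\right)}{\sqrt{\tilde{\sigma}^2 + \epsilon}}$$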

Why does BatchNorm sometimes hurt when my batch size is only 8 samples?

With such a small batch, the estimates of mean and variance are extremely noisy, and validation accuracy can lag behind training by several percentage points. Switch to Ghost Batch Norm or Group Norm; note that gradient accumulation simulates a larger batch for the gradients, but BatchNorm statistics are still computed on the small micro-batch.
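A rough sketch of the Ghost Batch Norm idea (an illustration, not a drop-in for any particular library implementation): the incoming batch is split into virtual chunks and each chunk is normalised with its own statistics during training.

import torch
import torch.nn as nn

class GhostBatchNorm2d(nn.Module):
    """Normalise each virtual sub-batch separately during training."""

    def __init__(self, num_features: int, virtual_batch_size: int = 32):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and x.size(0) > self.virtual_batch_size:
            # Split the real batch into virtual chunks and normalise each with its own stats.
            chunks = x.split(self.virtual_batch_size, dim=0)
            return torch.cat([self.bn(chunk) for chunk in chunks], dim=0)
        return self.bn(x)   # evaluation: ordinary BatchNorm with running statistics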

Is Layer Normalization always a better choice than BatchNorm?

No. Layer Norm is independent of batch size and shines in transformer or sequence models, but it can slow down convolutional networks. For pixel-heavy workloads, BatchNorm (or Group/EvoNorm) still gives better speed-accuracy trade-offs.