Sommaire
Si vous avez déjà lu la mise en œuvre d'un modèle d'apprentissage profond, il est probable que vous ayez déjà rencontré BatchNorm (Batch Normalization). Il s'agit d'une opération très courante qui est utilisée pour accélérer l'entraînement des grands modèles et pour stabiliser les modèles instables. Cependant, si vous êtes praticien, il est fort possible que vous ayez également eu du mal à subir cette opération, qui pose notoirement de nombreux problèmes. Dans cet article, nous passerons en revue les problèmes que nous rencontrons souvent et proposerons quelques solutions.
Qu'est-ce qu'une couche de normalisation par lots ?
BatchNorm vise à résoudre le problème du décalage des covariables. Cela signifie que pour une couche donnée d'un réseau profond, la sortie présente une moyenne et un écart type sur l'ensemble de données. Pendant l'entraînement, cette moyenne et cet écart type ne sont pas limités et peuvent évoluer de manière aléatoire, ce qui peut poser des problèmes de stabilité numérique. L'opération BatchNorm tente de résoudre ce problème en normalisant la sortie de la couche. Cependant, il est trop coûteux d'évaluer la moyenne et l'écart type sur l'ensemble de données, c'est pourquoi nous ne les évaluons que sur un lot de données.
Cela fonctionne bien dans la pratique, mais nous ne pouvons pas faire de même au moment de l'inférence, car nous recevons les données une par une, donc les moyennes n'ont plus de sens. Afin de résoudre ce problème, les implémentations modernes proposent de calculer une moyenne courante sur les données.
Le problème
En résumé, le comportement est différent entre l'entraînement et l'inférence. Au moment de l'entraînement t, mt, et σt sont utilisés, mais au temps d'inférence mt et σt sont utilisés. Cette différence est à l'origine de tous les maux, car les paramètres de validation et de formation peuvent être très différents. Plus précisément, à mesure que la quantité réelle évolue au cours de l'entraînement, la moyenne cumulée est souvent à la traîne, ce qui peut entraîner une différence significative. En principe, si le lot est volumineux et si le modèle converge correctement, ces quantités devraient devenir les mêmes. Mais dans la pratique, c'est souvent faux ou peu pratique. Par exemple, il ne sera pas évident si un écart important entre la perte d'entraînement et la perte de validation est dû à un surajustement important ou au fait que ces quantités n'ont pas encore convergé.
Plus dangereux encore, nous observons régulièrement que, bien que la perte d'entraînement converge vers une certaine valeur, la perte de validation peut rester considérablement plus élevée, car la moyenne et l'écart type du BatchNorm ne se stabilisent jamais. Nous, les auteurs, ne sommes pas tout à fait sûrs de la cause du problème, mais nous pensons que cela peut se produire lorsque le minimum est fortement dégénéré. Par exemple, dans un paysage de pertes tel qu'illustré ci-dessous, le modèle se déplacera de manière aléatoire dans la vallée circulaire, ce qui entraînera un retard permanent de la moyenne courante.

La solution
La première chose à faire si vous rencontrez ce problème est d'essayer quelques astuces standard. En voici quelques-unes typiques :
- Essayez d'utiliser une autre solution de normalisation (par ex. Norme de couche, Norme d'instance...) ;
- Augmenter la taille du lot, ce qui peut stabiliser l'estimation de la moyenne et de l'écart type entre les lots ;
- Jouez avec le paramètre momentum de la moyenne mobile. Il vous indique dans quelle mesure les lots précédents persistent dans la moyenne courante, c'est-à-dire dans quelle mesure les estimations peuvent « prendre du retard » ;
- Mélangez votre ensemble d'entraînement à chaque époque, afin d'éviter toute corrélation entre les points de données.
Cependant, parfois, ces astuces de base ne suffisent pas. Dans ce cas, nous proposons d'utiliser une astuce plus puissante.
Gardez à l'esprit que nous avons deux comportements différents de la couche BatchNorm :
- Dans ce que nous appellerons le mode d'estimation par lots, la moyenne et les écarts types sont estimés sur le lot. Il s'agit du mode utilisé lors de l'entraînement ;
- Dans ce que nous appellerons le mode d'inférence, la moyenne et l'écart type sont basés sur des estimations précédentes, c'est-à-dire sur la moyenne courante. C'est ce qui est généralement utilisé lors de la validation et de l'inférence.
Notre solution se compose de deux étapes ! Tout d'abord, nous désactivons la différence entre l'entraînement et la validation en utilisant toujours le mode d'estimation par lots. Deuxièmement, pour utiliser le modèle en production, nous devons encore estimer la moyenne et l'écart type pour pouvoir utiliser le mode inférence. Ainsi, une fois le modèle entraîné, nous calculons la moyenne et l'écart type à utiliser. Ce faisant, ils sont évalués sur un modèle à pondération fixe et nous évitons l'effet de « retard » décrit précédemment. Plus concrètement, après l'entraînement, nous figeons tous les poids du modèle et exécutons une époque plus tard pour estimer la moyenne mobile de l'ensemble de données.
Expérimenter la solution
Afin de montrer l'avantage de notre solution, faisons une petite expérience. Nous avons délibérément utilisé une très mauvaise architecture et l'avons entraînée avec un taux d'apprentissage relativement élevé, ce qui a donné naissance à un modèle avec BatchNorm instable. Le code écrit en Python et l'utilisation de PyTorch est disponible.
Le réseau est un empilement de 3 couches de convolution, avec une activation par BatchNorm et ReLU suivie d'une couche de regroupement moyenne globale. Nous l'avons entraîné sur MNIST pendant 10 époques à l'aide de l'algorithme d'optimisation Adam. La figure ci-dessous montre la précision de l'entraînement et de la validation par époque dans 4 modes :
- Mode 0 : aucune couche BatchNorm n'est utilisée.
- Mode 1 : BatchNorm de base sans aucune modification.
- Mode 2 : Almost Smart BatchNorm : nous avons activé les statistiques courantes à des fins d'inférence, mais nous n'avons pas utilisé l'époque du modèle 1 pour estimer la moyenne mobile des statistiques.
- Mode 3 : Smart BatchNorm : nous estimons sur 1 époque les statistiques moyennes du jeu de données avant le mode inférence.

Nous observons deux choses. Tout d'abord, BatchNorm contribue à augmenter la précision. Deuxièmement, sans notre solution, la métrique de validation est erratique et peu informative. Enfin, nous fournissons la précision du test pour les 4 situations.


Comme vous pouvez le constater, nous pourrions obtenir de meilleurs résultats en utilisant notre solution. Le troisième mode est vraiment mauvais : nous activons les statistiques en cours (mode inférence) mais nous n'estimons pas ces statistiques sur l'ensemble de données, donc lorsque nous testons dans des conditions d'inférence avec une taille de lot de 1, nous obtenons de mauvais résultats. Cela montre la nécessité de combiner les statistiques courantes au moment de l'inférence avec l'estimation des statistiques de l'ensemble de données sur une époque entière de l'ensemble de données avant d'utiliser le modèle pour l'inférence.
Cette solution est-elle parfaite ?
Non, évidemment pas ! Beaucoup de mauvaises choses peuvent encore se produire. Le plus délicat est que votre moyenne estimée et votre écart-type seront toujours différents de l'estimation par lots et que certains phénomènes vraiment étranges peuvent tout de même vous toucher durement. Par exemple, il a été démontré que certains modèles peuvent réellement coder des informations dans un bruit statistique. Heureusement, ces cas extrêmes sont très rares et l'expérience a montré que cette solution est assez robuste, elle ne devrait qu'améliorer vos performances et vous éviter bien des maux de tête. Si vous voulez éviter des comportements étranges avec vos couches BatchNorm, allez-y.
Image de fond réalisée par Pietro Jeng
À propos



.webp)
.webp)

.webp)