Symmetries in Overparametrized Neural Networks: A Mean Field View

Properties
authors Javier Maass Martinez, Joaquin Fontbona
year 2024
url https://arxiv.org/abs/2405.19995

Abstract

We develop a Mean-Field (MF) view of the learning dynamics of overparametrized Artificial Neural Networks (NN) under data symmetric in law with respect to the action of a general compact group G. To this end, we consider a class of generalized shallow NNs given by an ensemble of N multi-layer units, jointly trained using stochastic gradient descent (SGD) and possibly symmetry-leveraging (SL) techniques, such as Data Augmentation (DA), Feature Averaging (FA) or Equivariant Architectures (EA). We introduce the notions of weakly and strongly invariant laws (WI and SI) on the parameter space of each single unit, corresponding, respectively, to G-invariant distributions and to distributions supported on parameters fixed by the group action (which encode EA). This allows us to define symmetric models compatible with taking N→∞ and to interpret the asymptotic dynamics of DA, FA and EA in terms of Wasserstein Gradient Flows describing their MF limits. When activations respect the group action, we show that, for symmetric data, DA, FA and freely-trained models obey the exact same MF dynamics, which stay in the space of WI laws and minimize the population risk therein. We also give a counterexample to the general attainability of an optimum over SI laws. Despite this, quite remarkably, we show that the set of SI laws is also preserved by the MF dynamics even when the model is freely trained. This contrasts sharply with the finite-N setting, in which EAs are generally not preserved by unconstrained SGD. We illustrate the validity of our findings as N gets larger in a teacher-student experimental setting, training a student NN to learn from a WI, SI or arbitrary teacher model through various SL schemes. Finally, we deduce a data-driven heuristic to discover the largest subspace of parameters supporting SI distributions for a given problem, which could be used for designing EAs with minimal generalization error.
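
The distinction between WI and SI laws can be made concrete with a small numerical sketch. The example below is an illustrative toy case and is not taken from the paper: it assumes G = {I, P} acting on R^2 by swapping coordinates, a trivial action on the output, tanh units, and a finite ensemble standing in for the parameter law. All function names (`model`, `feature_average`, `symmetrize_particles`, `project_to_fixed`) are hypothetical. Under these assumptions, an empirical law made G-invariant by adding the image of each particle yields the same model as feature averaging, while particles projected onto the fixed subspace yield an invariant model with no averaging at all.

```python
import numpy as np

# Hypothetical toy setup: G = {I, P}, where P swaps the two coordinates of R^2.
P = np.array([[0.0, 1.0], [1.0, 0.0]])
GROUP = [np.eye(2), P]


def model(a, W, X):
    """Shallow net f(x) = (1/N) * sum_i a_i * tanh(w_i . x), for each row x of X."""
    return np.tanh(X @ W.T) @ a / len(a)


def feature_average(a, W, X):
    """Feature averaging: average the model output over the orbit of each input."""
    return np.mean([model(a, W, X @ g.T) for g in GROUP], axis=0)


def symmetrize_particles(a, W):
    """Weakly invariant empirical law: add the image (a_i, P w_i) of every particle."""
    return np.concatenate([a, a]), np.concatenate([W, W @ P.T])


def project_to_fixed(W):
    """Strongly invariant particles: project each w_i onto the subspace fixed by G."""
    return np.mean([W @ g.T for g in GROUP], axis=0)


rng = np.random.default_rng(0)
N = 8
a = rng.normal(size=N)        # output weights (assumed unaffected by the group)
W = rng.normal(size=(N, 2))   # input weights, one row per unit
X = rng.normal(size=(5, 2))   # a few test inputs

a_wi, W_wi = symmetrize_particles(a, W)   # WI: the particle set is closed under G
W_si = project_to_fixed(W)                # SI: every particle is fixed by G

# The WI ensemble reproduces feature averaging of the original ensemble.
print(np.allclose(model(a_wi, W_wi, X), feature_average(a, W, X)))
# Both symmetric ensembles give outputs invariant under swapping the inputs.
print(np.allclose(model(a_wi, W_wi, X @ P.T), model(a_wi, W_wi, X)))
print(np.allclose(model(a, W_si, X @ P.T), model(a, W_si, X)))
```

In this sketch the WI ensemble needs twice as many units to be invariant, whereas the SI ensemble keeps N units but restricts each weight vector to the line w[0] = w[1], which is the finite-N analogue of an equivariant architecture in this toy case.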