Singular Learning Theory

NOTE: Work in Progress

Abstract


Introduction

Regular learning theory is lying to you: "overparametrized" models actually aren't overparametrized, and generalization is not just a question of broad basins.

The standard explanation for why neural networks generalize is that gradient descent settles in flat basins of the loss landscape. On the left, in a sharp minimum, the updates bounce the model around, and performance varies wildly with new examples. On the right, in a flat minimum, the updates settle to zero, and performance is stable under small perturbations.

The trouble with this story is that loss "basins" actually aren't basins but valleys, and at the base of these valleys lie manifolds of constant, minimum loss. The higher the dimension of these "rivers", the lower the effective dimensionality of your model. Generalization is a balance between expressivity (more effective parameters) and simplicity (fewer effective parameters).

Singular directions lower the effective dimensionality of your model. In this example, a line of degenerate points effectively restricts the two-dimensional loss surface to one dimension.

These manifolds correspond to the internal symmetries of NNs: continuous variations of a given network that perform the same calculation. Many of these symmetries are predetermined by the architecture and so are always present. We call these "generic". The more interesting symmetries are non-generic symmetries, which the model can form or break during training.

In this light, part of the power of NNs is that they can vary their effective dimensionality (thus also expressivity). Generality comes from a kind of "forgetting" in which the model throws out unnecessary dimensions. At the risk of being elegance-sniped, SLT seems like a promising route to develop a better understanding of training dynamics (and phenomena such as sharp left turns and path-dependence). If we're lucky, SLT may even enable us to construct a grand unified theory of scaling.

A lot still needs to be done (esp. in terms of linking the Bayesian presentation of singular learning theory to conventional machine learning), but, from an initial survey, singular learning theory feels meatier than other explanations of generalization.1 So let me introduce you to the basics…

Singular Models

Maximum likelihood estimation is KL-divergence minimization.

We're aiming for a shallow introduction of questionable rigor. For full detail, I recommend Carroll's MSc thesis here (whose notation I am adopting).

The setting is Bayesian, so we'll start by translating the setup of a "standard" regression problem into more appropriate Bayesian language.

We have some true distribution $q(y|x)$ and some model $p(y|x, w)$ parametrized by weights $w \in W \subseteq \mathbb R^D$. Our aim is to learn the weights that make $p$ as "close" as possible to $q$.

Given a dataset $\mathcal D = \{(x_i, y_i)\}_{i=1}^n$, frequentist learning is usually formulated in terms of the empirical likelihood of our data (which assumes that each sample is i.i.d.):

$$p(\mathbf y|\mathbf x, w) = \prod_{i=1}^n p(y_i|x_i,w).$$

The aim of learning is to find the weights that maximize this likelihood (hence "maximum likelihood estimator"):

$$w^* = \text{argmax}_w\, p(\mathbf y|\mathbf x, w).$$

That is: we want to find the weights which make our observations as likely as possible.

In practice, because sums are easier than products and because we like our bits to be positive, we end up trying to minimize the negative log likelihood instead of the vanilla likelihood. That is, we're minimizing average bits of information rather than maximizing probabilities:

$$L_n(w) := -\frac{1}{n}\log p(\mathbf y | \mathbf x, w) = -\frac{1}{n}\sum_{i=1}^n\log p(y_i|x_i, w).$$

If we define the empirical entropy, $S_n$, of the true distribution,

$$S_n := -\frac{1}{n}\sum_{i=1}^n \log q(y_i|x_i),$$

then, since $S_n$ is independent of $w$, we find that minimizing $L_n(w)$ is equivalent to minimizing the empirical Kullback-Leibler divergence, $K_n(w)$, between our model and the true distribution:

$$K_n(w) := \frac{1}{n}\sum_{i=1}^n\log \frac{q(y_i|x_i)}{p(y_i|x_i, w)} = L_n(w) - S_n.$$

So maximizing the likelihood is not just some half-assed frequentist heuristic. It's actually an attempt to minimize the most straightforward information-theoretic "distance" between the true distribution and our model. In the limit of infinite data, the empirical $K_n(w)$ converges to the KL-divergence proper:

$$K(w) := D_{\mathrm{KL}}\big(q(y, x)\,\|\,p(y, x | w)\big) = \mathbb E_{q(y, x)}\left[\log \frac{q(y|x)}{p(y|x, w)}\right].$$

The advantage of working with the KL-divergence is that it's bounded: $K(w) \geq 0$, with equality iff $p(y|x, w) = q(y|x)$ almost everywhere.
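To make this concrete, here's a minimal numerical sketch (assuming a toy setup of my own: true distribution $q(y|x) = \mathcal N(y \mid 2x, 1)$ and model family $p(y|x, w) = \mathcal N(y \mid wx, 1)$). Since $K_n(w) = L_n(w) - S_n$ and $S_n$ does not depend on $w$, the empirical NLL and the empirical KL-divergence are minimized by the same weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (an assumption of this sketch, not from the post):
# true distribution q(y|x) = N(y | 2x, 1), model p(y|x, w) = N(y | w*x, 1).
n = 10_000
x = rng.normal(size=n)
y = 2.0 * x + rng.normal(size=n)

def L_n(w):
    """Empirical negative log-likelihood (average, in nats)."""
    return np.mean(0.5 * (y - w * x) ** 2 + 0.5 * np.log(2 * np.pi))

S_n = L_n(2.0)                # empirical entropy of the true conditional (w* = 2)
ws = np.linspace(0.0, 4.0, 401)
L = np.array([L_n(w) for w in ws])
K = L - S_n                   # empirical KL-divergence K_n(w)

# Both objectives pick out the same w (~2), and K_n is ~0 there.
print(ws[np.argmin(L)], ws[np.argmin(K)], K.min())
```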

In this frame, our learning task is not simply to minimize the KL-divergence, but to find the true parameters:

$$W_0 := \{w \in W \mid K(w)=0\} = \{w \in W \mid p(y|x, w) = q(y|x)\}.$$

Note that it is not necessarily the case that a set of true parameters actually exists. If your model is insufficiently expressive, then the true model need not be realizable: your best fit may have some non-zero KL-divergence.

Still, from the perspective of generalization, it makes more sense to talk about true parameters than simply the KL-divergence-minimizing parameters. It's the true parameters that give us perfect generalization (in the limit of infinite data).

The Bayesian Information Criterion is a lie.

One of the main strengths of the Bayesian frame is that it lets you enforce a prior $\varphi(w)$ over the weights, which you can integrate out to derive a parameter-free model:

$$p(y|x) = \int_W p(y|x, w)\,\varphi(w)\,\mathrm{d}w.$$

One of the main weaknesses is that this integral is almost always intractable. So Bayesians make a concession to the frequentists with a much more tractable Laplace approximation2 (i.e., you approximate your model as quadratic/Gaussian in the vicinity of the maximum likelihood estimator (MLE), $w^{(0)}$):

$$K(w) \approx \frac{1}{2}(w-w^{(0)})^T I(w^{(0)}) (w-w^{(0)}),$$

where $I(w)$ is the Fisher information matrix:

$$I_{j,k}(w)=\int\left(\frac{\partial}{\partial w_j} \log p(y|x, w)\right)\left(\frac{\partial}{\partial w_k} \log p(y |x, w)\right) p(y, x|w)\, \mathrm{d}x\, \mathrm{d}y.$$

The Laplace approximation is a probability theorist's Taylor approximation.
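As a sanity check on this quadratic picture, here's a sketch using the same toy model as above ($p(y|x,w) = \mathcal N(y \mid wx, 1)$, true $w^* = 2$, $x \sim \mathcal N(0,1)$): compare a Monte-Carlo estimate of $K(w)$ against $\frac{1}{2}(w-w^*)^2 I(w^*)$. For this one-parameter regular model the quadratic form is essentially exact; the interesting cases are the singular ones where it breaks down.

```python
import numpy as np

rng = np.random.default_rng(3)

# Same toy model as before (an assumption of this sketch):
# p(y|x, w) = N(y | w*x, 1), true w* = 2, x ~ N(0, 1).
w_star = 2.0
xs = rng.normal(size=200_000)
ys = w_star * xs + rng.normal(size=200_000)

# Fisher information at w*: I(w*) = E[(d/dw log p)^2], with score (y - w*x) * x.
score = (ys - w_star * xs) * xs
I_hat = np.mean(score**2)          # should be close to E[x^2] = 1

for w in (1.5, 1.9, 2.0, 2.1, 2.5):
    # Monte-Carlo estimate of K(w) = E_q[log q - log p]; the log-normalizers cancel.
    K_mc = np.mean(0.5 * (ys - w * xs) ** 2 - 0.5 * (ys - w_star * xs) ** 2)
    K_quad = 0.5 * I_hat * (w - w_star) ** 2
    print(w, round(K_mc, 4), round(K_quad, 4))
```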

From this approximation, a bit more math gives us the Bayesian information criterion (BIC):

$$\mathrm{BIC} = n L_n(w^{(0)}) + \frac{D}{2}\log n.$$

The BIC (like the related Akaike information criterion) is a criterion for model selection that penalizes complexity. Given two models, the one with the lower BIC tends to overfit less (/"generalize better").
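For illustration, here's a small sketch of BIC-based model selection under the convention above, using hypothetical polynomial regression models with a known noise scale; the data and models are my own stand-ins, not from the post.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical data: a quadratic trend plus Gaussian noise.
n = 200
x = rng.uniform(-2, 2, size=n)
y = 1.0 + 0.5 * x - x**2 + rng.normal(scale=0.3, size=n)

def bic(degree, sigma=0.3):
    """BIC = n * L_n(w_mle) + (D/2) * log(n) for a polynomial regression
    with known noise scale sigma (an assumption of this sketch)."""
    coeffs = np.polyfit(x, y, degree)                 # maximum-likelihood fit
    resid = y - np.polyval(coeffs, x)
    L_n = np.mean(0.5 * (resid / sigma) ** 2 + 0.5 * np.log(2 * np.pi * sigma**2))
    D = degree + 1                                    # naive parameter count
    return n * L_n + 0.5 * D * np.log(n)

for d in (1, 2, 8):
    print(d, round(bic(d), 1))   # the quadratic (d=2) should get the lowest BIC
```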

The problem with regular learning theory is that deriving the BIC invokes the inverse, $I^{-1}(w^{(0)})$, of the information matrix. If $I(w^{(0)})$ is non-invertible, then the BIC and all the generalization results that depend on it are invalid.

As it turns out, information matrices are pretty much never invertible for deep neural networks. So, we have to rethink our theory.

Singularities in the context of Algebraic Geometry

For an analytic function $K : W \to \mathbb R$, a point $x \in W$ is a critical point of $K$ if the gradient vanishes there, $\nabla K(x) = 0$. A singularity is a critical point at which $K$ itself also vanishes, $K(x) = 0$.

Under these definitions, any true parameter $w^*$ is a singularity of the KL-divergence: $K(w^*)=0$ follows from the definition of $w^*$, and $\nabla K(w^*) = 0$ follows from the lower bound $K(w) \geq 0$.

So another advantage of the KL-divergence over the NLL is that it gives us a clean lower bound, under which any true parameter $w^*$ is a singularity of $K$.

We are interested in degenerate singularities: singularities that are not isolated but occupy a common manifold. At a degenerate singularity, there is some continuous change to $w^*$ which leaves $K(w^*)$ unchanged. That is, the surface is not locally parabolic.

Non-degenerate singularities are locally parabolic. Degenerate singularities are not.

In terms of $K$, this means that the Hessian at the singularity has at least one zero eigenvalue (equivalently, it is non-invertible). For the KL-divergence, the Hessian at a true parameter is precisely the Fisher information matrix we just saw.
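A minimal sketch of what such degeneracy looks like, using the stand-in potential $K(w_1, w_2) = w_1^2 w_2^2$ (my example, not the post's): every point on the two axes is a singularity, and the Hessian there always has at least one zero eigenvalue.

```python
import numpy as np

# Stand-in potential K(w1, w2) = w1^2 * w2^2: its zero set is the two axes,
# a one-dimensional manifold of degenerate singularities.
def K(w):
    return w[0] ** 2 * w[1] ** 2

def grad(w):
    return np.array([2 * w[0] * w[1] ** 2, 2 * w[0] ** 2 * w[1]])

def hessian(w):
    return np.array([
        [2 * w[1] ** 2, 4 * w[0] * w[1]],
        [4 * w[0] * w[1], 2 * w[0] ** 2],
    ])

for w in ([0.0, 0.0], [0.0, 1.5], [0.7, 0.0]):
    w = np.array(w)
    eigs = np.linalg.eigvalsh(hessian(w))
    # K = 0, the gradient vanishes, and at least one Hessian eigenvalue is 0.
    print(w, K(w), grad(w), np.round(eigs, 3))
```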

Generic symmetries of NNs

Neural networks are full of symmetries that let you change the model's internals without changing the overall computation. This is where our degenerate singularities come from.

The most obvious symmetry is that you can permute weights without changing the overall computation. Given any two compatible linear transformations (i.e., weight matrices) $A$ and $B$, an element-wise activation function $\phi$, and any permutation $P$,

$$B \circ \phi \circ A = (B \circ P^{-1}) \circ \phi \circ (P \circ A),$$

because permutations commute with $\phi$. The non-linearity of $\phi$ means this isn't the case for invertible transformations in general.

$$\begin{pmatrix} \color{red} a & \color{blue} b & c\\ \color{red} d & \color{blue}e & f \\ \color{red}g & \color{blue} h & i \end{pmatrix} \cdot \begin{pmatrix} \color{red}j & \color{red}k & \color{red}l\\ \color{blue}m & \color{blue}n & \color{blue}o \\ p & q & r \end{pmatrix} =\begin{pmatrix} \color{blue} b & \color{red} a & c\\ \color{blue} e & \color{red}d & f \\ \color{blue}h & \color{red} g & i \end{pmatrix} \cdot \begin{pmatrix} \color{blue}m & \color{blue}n & \color{blue}o \\ \color{red}j & \color{red}k & \color{red}l\\ p & q & r \end{pmatrix}$$

An example using the identity function for $\phi$.
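A quick numerical sketch of the permutation symmetry for a toy two-layer ReLU network (hypothetical shapes, nothing specific to the post): permuting the rows of $A$ and un-permuting the columns of $B$ leaves the output unchanged.

```python
import numpy as np

rng = np.random.default_rng(4)
relu = lambda z: np.maximum(z, 0.0)   # element-wise activation

A = rng.normal(size=(5, 3))           # first weight matrix
B = rng.normal(size=(2, 5))           # second weight matrix
x = rng.normal(size=3)

P = np.eye(5)[rng.permutation(5)]     # random permutation matrix; P^{-1} = P.T

out_original = B @ relu(A @ x)
out_permuted = (B @ P.T) @ relu(P @ A @ x)

print(np.allclose(out_original, out_permuted))  # True
```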

At least for this post, we'll ignore this symmetry as it is discrete, and we're interested in continuous symmetries that can give us degenerate singularities.

A more promising continuous symmetry is the following (for models that use ReLUs):

$$\text{ReLU}(x) =\frac{1}{\alpha}\text{ReLU}(\alpha x),\quad \alpha > 0.$$

For a ReLU layer, you can continuously scale the incoming pre-activation as long as you inversely scale the outgoing activation. Since this symmetry is present over the entire parameter space, $W$, nowhere in weight space is safe from degeneracy.
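Here's the same kind of sketch for the rescaling symmetry (again a toy two-layer network without biases; with biases you'd rescale them along with the incoming weights): scaling the incoming weights by $\alpha$ and the outgoing weights by $1/\alpha$ leaves the function unchanged.

```python
import numpy as np

rng = np.random.default_rng(2)
relu = lambda z: np.maximum(z, 0.0)

A = rng.normal(size=(16, 8))   # incoming weights of a ReLU layer
B = rng.normal(size=(4, 16))   # outgoing weights
x = rng.normal(size=8)

alpha = 3.7                    # any alpha > 0 works
f_original = B @ relu(A @ x)
f_rescaled = (B / alpha) @ relu((alpha * A) @ x)

# True: a whole continuous family of weights computes the same function.
print(np.allclose(f_original, f_rescaled))
```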

As an exercise for the reader, can you think of any other generic symmetries in deep neural networks?3

Both of these symmetries are generic. They're an after-effect of the architecture choice, and are always active. The more interesting symmetries are non-generic symmetries, those that depend on $w$.

Non-Generic Symmetries

The key observation of singular learning theory for neural networks is that they can vary their effective parameter count.

Real Log Canonical Threshold (RLCT)

Zeta function of $K(w)$

$$\zeta(z)=\int_W K(w)^z \varphi(w)\,\mathrm{d}w,$$

where $\varphi(w)>0$ for all $w \in W$.

Analytically continue this to a meromorphic function on the whole complex plane; then the largest pole (the one closest to zero), $-\lambda$, gives the RLCT $\lambda$, and the order of that pole gives its multiplicity $m$.

Missing a good image of what is going on here. Yes, the pole is the location of a singularity ($\zeta(z) \to \infty$), and its multiplicity is the order of the polynomial you need to approximate the local behavior of the corresponding zero of $\zeta^{-1}$. But how do I actually interpret $z$?

Real log canonical threshold ($\lambda$)

  • The RLCT is the volume co-dimension (the number of effective parameters near the most singular point of $W_0$).
  • For regular (non-singular) models, the RLCT is precisely $D/2$ (see the worked example after this list).
  • Why divide by two?
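As a minimal worked example (my own, under the assumption of a uniform prior on $[-1,1]^D$): for the regular one-parameter case $K(w) = w^2$,

$$\zeta(z) = \int_{-1}^{1} (w^2)^z\,\tfrac{1}{2}\,\mathrm{d}w = \int_0^1 w^{2z}\,\mathrm{d}w = \frac{1}{2z+1},$$

which has its only pole at $z = -\tfrac{1}{2}$, so $\lambda = \tfrac{1}{2} = D/2$. By contrast, for the singular two-parameter potential $K(w_1, w_2) = w_1^2 w_2^2$, the integral factorizes into $\left(\frac{1}{2z+1}\right)^2$: the largest pole is still $z = -\tfrac{1}{2}$, now of order two, so $\lambda = \tfrac{1}{2} < D/2 = 1$ with multiplicity $m = 2$. Fewer effective parameters than nominal parameters.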

Orientation-reversing symmetries

Using ReLUs, we can imagine our network performing a kind of piece-wise, high-dimensional spline approximation of the input function. In higher dimensions, the pieces are separated by hypersurfaces: the activation boundaries of individual units.

These boundaries are described by setting a linear equation in the parameters to zero. Such an equation carries an orientation; if we reverse the orientation, we get the same boundary.

Degenerate nodes

In addition to these symmetries, when the model has more hidden nodes than the true model, the excess nodes are either degenerate or share an activation boundary with another node.

NOTE: I'm still confused here.

INSERT comment on equivariance (link Distill article)


Glossary

  • If these parameters exist at all (i.e., this set is non-empty, and there is some choice of weights $w_0$ for which $K(w_0) = 0$), we say our model is realizable. We'll assume this is the case from now on.
  • When every choice of weights corresponds to a unique model (i.e., the map $w \mapsto p(y|x, w)$ is injective), we say our model is identifiable.
  • If a model is identifiable and its Fisher information matrix is positive definite, then the model is regular. Otherwise, the model is strictly singular.

Singular learning theory kicks into gear when our models are strictly singular. Equivalently: if the Fisher information matrix is strictly positive definite everywhere (no zero eigenvalues for any $w$), an identifiable model is called regular; a non-regular model is called strictly singular.

Example: Pendulum

Our objects of study are triples of the kind $(p(y|x,w), q(y|x), \varphi(w))$, where $p(y|x, w)$ is a model of some unknown true distribution, $q(y|x)$, with a prior over the weights, $\varphi(w)$.

The model itself is a regression model on $f$:

$$p(y|x, w) = \mathcal{N}(y \mid f(x,w),\ \mathbb 1_M).$$

We have some true distribution $q(y|x)$ and a model $p(y|x, w)$ parameterized by $w$. Our aim is to learn the weights so as to capture the true distribution, and we assume $q(y|x)$ is realizable.

Let's start with an example to understand how learning changes when your models are singular.

You have a pendulum with some initial angular displacement $x_0$ and velocity $v_0$. Newton tells us that at time $t$, it'll be in position

$$x_t = f(x_0, v_0, g, t) = x_0 \cos(\sqrt{g}\cdot t) + \frac{v_0}{\sqrt{g}}\sin(\sqrt{g}\cdot t).$$

The problem is that we live in the real world. Our measurement of $x_t$ is noisy:

$$p(x_t|x_0, v_0, g, t) = \mathcal N(x_t \mid f(x_0, v_0, g, t),\ \epsilon^2).$$

What we'd like to do is to learn the "true" parameters $(x_0^*, v_0^*, g^*, t^*)$ from our observations of $x_t$. That gives us a problem: every parameter setting in the following family determines exactly the same model:

$$\left\{\left(x_0^*,\ a v_0^*,\ a^2 g^*,\ \tfrac{1}{a}t^*\right)\ \middle|\ a > 0\right\}.$$

That is: our map from parameters to models is non-injective. Multiple parameters determine the same model. We call these models strictly singular.
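A quick numerical check of this orbit of equivalent parameters (a sketch; the specific numbers are arbitrary):

```python
import numpy as np

def pendulum(x0, v0, g, t):
    """Position at time t given initial displacement x0, velocity v0, gravity g."""
    return x0 * np.cos(np.sqrt(g) * t) + v0 / np.sqrt(g) * np.sin(np.sqrt(g) * t)

# Arbitrary "true" parameters and measurement times (assumptions of this sketch).
x0, v0, g = 0.3, 1.2, 9.8
t = np.linspace(0.0, 5.0, 100)

for a in (0.5, 2.0, 7.0):
    # The rescaled parameters (x0, a*v0, a^2*g, t/a) trace out the same positions.
    same = np.allclose(pendulum(x0, v0, g, t), pendulum(x0, a * v0, a**2 * g, t / a))
    print(a, same)  # True for every a > 0
```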

  • Aim: Modeling a pendulum given a noisy estimate of $(x, y)$ at time $t$.

  • The parameters of our model are $(\lambda, g, t)$ (initial velocity, gravitational acceleration, time of measurement), and the model is $f(x, y; \lambda, g, t) = x \cos(\sqrt g\cdot t) + \frac{\lambda}{\sqrt g}\sin(\sqrt g\cdot t)$.

  • The map from parameters to models is non-injective. That is, the function $f$ is exactly the same after a suitable mapping (like $g \to 4g$, $\lambda \to 2 \lambda$, $t \to t/2$).
  • You can reparameterize this model to get rid of the degeneracy ($\lambda' = \lambda/\sqrt{g}$, $t' = \sqrt g \cdot t$):
$$f(x, y; \lambda', t') = x \cos(t') + \lambda'\sin(t').$$
  • But that may actually make the parameters less useful to reason about, and, in general, may make the "true" model harder to find.

(If you look at the $(t, \lambda)$ plane, you get straight-line level sets; same for the $(t, \sqrt g)$ plane.)

Relation to the Hessian of $K(w)$ (the KL-divergence between a parameter and the true model)

…TODO

Connection to basin broadness

…TODO

Emphasize this early on.

Questions

  • Does SGD actually reach these solutions? For a given loss, if we are uniformly distributed across all weights with that loss, we should end up in simpler solutions, right? Does this actually happen though?
  • Is part of the value of depth that it creates more ReLU-like symmetries? Can you create equally successful shallow, wide models if you hardcode additional symmetries?

Footnotes

  1. E.g.: that explicit regularization enforces simpler solutions (weight decay is a Gaussian prior over weights), that SGD settles in broader basins that are more robust to changes in parameters (=new samples), that NNs have Solomonoff-like inductive biases [1], or that highly correlated weight matrices act as implicit regularizers [2].

  2. This is just a second-order Taylor approximation modified for probability distributions. That is, the Fisher information matrix gives you the curvature of the negative log likelihood: it tells you how many bits you gain (=how much less likely your dataset becomes) as you move away from the minimum in parameter space.

  3. Hint: normalization layers, the encoding/unencoding layer of transformers / anywhere else without a privileged basis.