Feed

Spooky action at a distance in the loss landscape

Not all global minima of the (training) loss landscape are created equal.

Even if they achieve equal performance on the training set, different solutions can perform very differently on the test set or out-of-distribution. So why is it that we typically find "simple" solutions that generalize well?

In a previous post, I argued that the answer is "singularities" — minimum loss points with ill-defined tangents. It's the "nastiest" singularities that have the most outsized effect on learning and generalization in the limit of large data. These act as implicit regularizers that lower the effective dimensionality of the model.

Even after writing this introduction to "singular learning theory", I still find this claim weird and counterintuitive. How is it that the local geometry of a few isolated points determines the global expected behavior over all learning machines on the loss landscape? What explains the "spooky action at a distance" of singularities in the loss landscape?

Today, I'd like to share my best efforts at the hand-waving physics-y intuition behind this claim.

It boils down to this: singularities translate random motion at the bottom of loss basins into search for generalization.

Random walks on the minimum-loss sets

Let's first look at the limit in which you've trained so long that we can treat the model as restricted to a set of fixed minimum loss points1.

Here's the intuition pump: suppose you are a random walker living on some curve that has singularities (self-intersections, cusps, and the like). Every timestep, you take a step of a uniform length in a random available direction. Then, singularities act as a kind of "trap." If you're close to a singularity, you're more likely to take a step towards (and over) the singularity than to take a step away from the singularity.

It's not quite an attractor (we're in a stochastic setting, where you can and will still break away every so often), but it's sticky enough that the "biggest" singularity will dominate your stationary distribution.

In the discrete case, this is just the well-known phenomenon of high-degree nodes dominating most of the expected behavior of your graph. In business, it's the reason that Google exists. In social networks, it's similar to how your average friend has more friends than you do.

singularity.png

To see this, consider a simple toy example: take two polygons and let them intersect at a single point. Next, let a random walker run loose on this setup. How frequently will the random walker cross each point?


Take two or more 1D lattices with toroidal boundary conditions and let them intersect at one point. In the limit of an infinite polygon/lattice, you end up with a normal crossing singularity at the origin.

If you've taken a course in graph theory, you may remember that the stationary distribution weights nodes in proportion to their degrees. For two intersecting lines, the origin is twice as likely as the other points. For three intersecting lines, it's three times as likely, and so on…


The size of each circle shows how likely that point is under empirical simulation. The stationary distribution puts weight on the origin in proportion to the number of intersecting lines.

Now just take the limit of infinitely large polygons/step size to zero, and we'll recover the continuous case we were originally interested in.
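Here's a minimal sketch of the kind of simulation I mean (this isn't the linked notebook; the ring size and step count are just illustrative): two rings glued together at a single node, with a walker that tallies where it spends its time.

```python
import numpy as np

# A minimal sketch: two rings of N nodes glued together at node 0.
# The glued node has degree 4; every other node has degree 2.
N, steps = 50, 1_000_000
rng = np.random.default_rng(0)

visits = np.zeros((2, N), dtype=int)
ring, pos = 0, 0
for _ in range(steps):
    if pos == 0:
        ring = rng.integers(2)             # at the crossing, both rings are available
    pos = (pos + rng.choice((-1, 1))) % N  # step left or right along the current ring
    if pos == 0:
        visits[0, 0] += 1                  # the crossing is shared between rings
    else:
        visits[ring, pos] += 1

origin_rate = visits[0, 0] / steps
other_rate = (steps - visits[0, 0]) / (steps * 2 * (N - 1))
print(origin_rate / other_rate)  # ≈ 2, matching the degree ratio
```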

Brownian motion near the minimum-loss set

Well, not quite. You see, restricting ourselves to motion along the minimum-loss points is unrealistic. We're more interested in messy reality, where we're allowed some freedom to bounce around the bottoms of loss basins.2

This time around, the key intuition-pumping assumption is to view the limiting behavior of stochastic gradient descent as a kind of Brownian motion — a source of randomness that no longer substantially improves loss but just jiggles us between solutions that are equivalent from the perspective of the training set.

To understand these dynamics, we can just study the more abstract case of Brownian motion in some continuous energy landscape with singularities.

Consider the potential function given by

U(\boldsymbol x) = a \cdot \min((x_0-b)^2, (x_1-b)^2).

This is plotted on the left side of the following figure. The right side depicts the corresponding stationary distribution as predicted by "regular" physics.

valleys.png

An energy landscape whose minimum loss set has a normal crossing singularity at the origin. Toroidal boundary conditions as in the discrete case.

Simulating Brownian motion in this well generates an empirical distribution that looks rather different from the regular prediction…

valleys 1.png

As in the discrete case, the singularity at the origin gobbles up probability mass, even at finite temperatures and even for points away from the minimum loss set.
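For the curious, here's a rough sketch of how you might reproduce this kind of experiment (this is not the linked notebook; the constants, temperature, and box size are illustrative): overdamped Langevin dynamics in the potential above, with toroidal boundary conditions.

```python
import numpy as np

# Euler-Maruyama discretization of overdamped Langevin dynamics
# in U(x) = a * min((x0 - b)^2, (x1 - b)^2) on the unit torus.
a, b = 5.0, 0.5
temperature, dt, steps = 0.05, 1e-3, 500_000
rng = np.random.default_rng(0)

def grad_U(x):
    g = np.zeros(2)
    i = 0 if (x[0] - b) ** 2 < (x[1] - b) ** 2 else 1  # active branch of the min
    g[i] = 2 * a * (x[i] - b)
    return g

x = rng.random(2)
traj = np.empty((steps, 2))
for t in range(steps):
    x = (x - grad_U(x) * dt + np.sqrt(2 * temperature * dt) * rng.standard_normal(2)) % 1.0
    traj[t] = x

# Compare the mass near the crossing singularity (b, b) with the mass near
# another point on the minimum-loss set, e.g. (b, 0.1).
near = lambda p: np.mean(np.linalg.norm(traj - np.asarray(p), axis=1) < 0.05)
print(near((b, b)), near((b, 0.1)))
```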

Takeaways

To summarize, the intuition3 is something like this: in the limiting case, we don't expect the model to learn much from any one additional sample. Instead, the randomness in drawing the new sample acts as Brownian motion that lets the model explore the minimum-loss set. Singularities are a trap for this Brownian motion which allow the model to find well-generalizing solutions just by moving around.

In short, singularities work because they transform random motion into useful search for generalization.

You can find the code for these simulations here and here.

Footnotes

  1. So technically, in singular learning theory we treat the loss landscape as changing with each additional sample. Here, we're considering the case that the landscape is frozen, and new samples act as a kind of random motion along the minimum-loss set.

  2. We're still treating the loss landscape as frozen but will now allow departures away from the minimum loss points.

  3. Let me emphasize: this is hand-waving/qualitative/physics-y jabber. Don't take it too seriously.

How to Perturb Weights

I'm running a series of experiments that involve some variation of: (1) perturb a weight initialization; (2) train the perturbed and baseline models in parallel, and (3) track how the perturbation grows/shrinks over time.

Naively, if we're interested in a perturbation analysis of the choice of weight initialization, we prepare some baseline initialization, \mathbf w_0, and then apply i.i.d. Gaussian noise, \boldsymbol\delta, to each of its elements, \delta_i \sim \mathcal N(0, \epsilon^2). (If we want, we can let this vary layer-by-layer and let it depend on, for example, the norm of the layer it's being applied to.)
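In code, the naive version looks something like this (a sketch; scaling the noise by the layer norm is just one of the options mentioned above):

```python
import copy
import torch

def perturb_naive(model: torch.nn.Module, eps: float) -> torch.nn.Module:
    """Add i.i.d. Gaussian noise to every parameter, scaled by that layer's norm."""
    perturbed = copy.deepcopy(model)
    with torch.no_grad():
        for p in perturbed.parameters():
            # eps is a relative perturbation size: noise std ≈ eps * (typical element size)
            p.add_(torch.randn_like(p) * eps * p.norm() / p.numel() ** 0.5)
    return perturbed
```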

The problem with this strategy is that the perturbed weights \mathbf w = \mathbf w_0 + \boldsymbol\delta are, in general, no longer sampled from the same distribution as the baseline weights.

There is nothing wrong with this per se, but it introduces a possible confounder (the shell thickness discussed below). This is especially relevant if we're interested specifically in how behavior changes with the size of the perturbation. As responsible experimentalists, we don't like confounders.

Fortunately, there's an easy way to "clean up" Kaiming initialization to make it better suited to this perturbative analysis.

Kaiming initialization lives in a hyperspherical shell

Consider a matrix, \mathbf w^{(l)}, representing the weights of a particular layer l with shape (D_\mathrm{in}^{(l)}, D_\mathrm{out}^{(l+1)}). D_\mathrm{in}^{(l)} is also called the fan-in of the layer, and D_\mathrm{out}^{(l+1)} the fan-out. For ease of presentation, we'll ignore the bias, though the following reasoning applies equally well to the bias.

We're interested in the vectorized form of this matrix, \vec w^{(l)} \in \mathbb R^{D^{(l)}}, where D^{(l)} = D_\mathrm{in}^{(l)} \times D_\mathrm{out}^{(l+1)}.

In Kaiming initialization, we sample the components, w_i^{(l)}, of this vector i.i.d. from a normal distribution with mean 0 and variance \sigma^2 (where \sigma^2 = \frac{2}{D_\mathrm{in}^{(l)}}).

Geometrically, this is equivalent to sampling from a hyperspherical shell, S^{D-1}, with radius \sqrt{D}\sigma and (fuzzy) thickness \delta. (Ok, so technically, because the radius can vary from layer to layer, it's a hyperellipsoidal shell.)

This follows from some straightforward algebra (dropping the superscript ll for simplicity):

\mathbb E[|\mathbf w|^2] = \mathbb E\left[\sum_{i=1}^D w_i^2\right] = \sum_{i=1}^D \mathbb E[w_i^2] = \sum_{i=1}^D \sigma^2 = D\sigma^2,

and

\begin{align} \delta^2 \propto \mathrm{var} [|\mathbf w|^2] &= \mathbb E\left[\left(\sum_{i=1}^D w_i^2\right)^2\right] - \mathbb E\left[\sum_{i=1}^D w_i^2\right]^2 \\ &= \sum_{i, j=1}^D \mathbb E[w_i^2 w_j^2] - (D\sigma^2)^2 \\ &= \sum_{i \neq j}^D \mathbb E[w_i^2] \mathbb E[w_j^2] + \sum_{i=1}^D \mathbb E[w_i^4] - (D\sigma^2)^2 \\ &= D(D-1) \sigma^4 + D(3\sigma^4) - (D\sigma^2)^2 \\ &= 2D\sigma^4. \end{align}

So the thickness as a fraction of the radius is

\frac{\delta}{\sqrt{D}\sigma} = \frac{\sqrt{2D}\,\sigma^2}{\sqrt{D}\,\sigma} = \sqrt{2}\sigma = \frac{2}{\sqrt{D_\mathrm{in}^{(l)}}},

where the last equality follows from the choice of σ\sigma for Kaiming initialization.

This means that for suitably wide networks (D_\mathrm{in}^{(l)} \to \infty), the thickness of this shell goes to 0.
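A quick numerical sanity check of this (the layer dimensions are just illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, trials = 1024, 64, 200
sigma = np.sqrt(2.0 / d_in)      # Kaiming variance
D = d_in * d_out

# Norms of many independently Kaiming-initialized layers.
norms = np.linalg.norm(rng.standard_normal((trials, D)) * sigma, axis=1)
print(norms.mean() / (np.sqrt(D) * sigma))  # ≈ 1: the radius is sqrt(D) * sigma
print(norms.std() / norms.mean())           # ≪ 1: the shell is thin and shrinks with D
```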

Taking the thickness to 0

This suggests an alternative initialization strategy. What if we immediately take the limit D_\text{in}^{(l)} \to \infty and sample directly from the boundary of a hypersphere with radius \sqrt{D}\sigma, i.e., set the shell thickness to 0?

This can easily be done by sampling each component from a normal distribution with mean 0 and variance 1 and then normalizing the resulting vector to have length \sqrt{D}\sigma (this is known as the Muller method).
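A sketch of that zero-thickness initializer (the helper name is mine, not from any library):

```python
import numpy as np

def init_on_sphere(d_in: int, d_out: int, rng=np.random.default_rng()) -> np.ndarray:
    """Sample a weight matrix lying exactly on the sphere of radius sqrt(D) * sigma."""
    D = d_in * d_out
    sigma = np.sqrt(2.0 / d_in)              # Kaiming variance
    v = rng.standard_normal(D)               # isotropic direction (Muller method)
    w = v * (np.sqrt(D) * sigma / np.linalg.norm(v))
    return w.reshape(d_in, d_out)
```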

Perturbing weight initialization

The modification we made to Kaiming initialization was to sample directly from the boundary of a hypersphere, rather than from a hyperspherical shell. This is a more natural choice when conducting a perturbation analysis, because it makes it easier to ensure that the perturbed weights are sampled from the same distribution as the baseline weights.

Geometrically, the intersection of a hypersphere S^D of radius w_0 = |\mathbf w_0| with a hypersphere S^D of radius \epsilon centered at some point on the boundary of the first hypersphere is a lower-dimensional hypersphere S^{D-1} of a modified radius \epsilon'. If we sample uniformly from this lower-dimensional hypersphere, then the resulting points will follow the same distribution over the original hypersphere.

This suggests a procedure to sample from the intersection of the weight initialization hypersphere and the perturbation hypersphere.

First, we sample from a hypersphere of dimension D-1 and radius \epsilon' (using the same technique we used to sample the baseline weights). From a bit of trigonometry (see the figure below), this radius is \epsilon' = w_0 \sin\theta, where \theta = \cos^{-1}\left(1-\frac{\epsilon^2}{2w_0^2}\right).


Next, we rotate the vector so that it is orthogonal to the baseline vector \mathbf w_0. This is done with a Householder reflection, H, that maps the current normal vector \hat{\mathbf n} = (0, \dots, 0, 1) onto -\hat{\mathbf w}_0 (the sign is irrelevant here; we only need the hyperplane containing the sample to end up orthogonal to \mathbf w_0):

H = \mathbf I - 2\frac{\mathbf c \mathbf c^T}{\mathbf c^T \mathbf c},

where

\mathbf c = \hat{\mathbf n} + \hat{\mathbf w}_0,

and \hat{\mathbf w}_0 = \frac{\mathbf w_0}{|\mathbf w_0|} is the unit vector in the direction of the baseline weights.

Implementation note: For the sake of tractability, we directly apply the reflection via:

H\mathbf y = \mathbf y - 2 \frac{\mathbf c^T \mathbf y}{\mathbf c^T\mathbf c} \mathbf c.

Finally, we translate the rotated intersection sphere along the baseline vector, so that its boundary goes through the intersection of the two hyperspheres. From the figure above, we find that the translation has magnitude w_0' = w_0 \cos \theta.

Since the baseline vector is sampled uniformly from its hypersphere and the perturbation is sampled uniformly from the intersection sphere, the resulting perturbed vector has the same distribution as the baseline vector when restricted to the intersection sphere.

sampling-perturation 1.png
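Putting the whole procedure together (a sketch under the assumptions above, with the \sin\theta radius for the intersection sphere; the helper name is mine):

```python
import numpy as np

def perturb_on_sphere(w0: np.ndarray, eps: float, rng=np.random.default_rng()) -> np.ndarray:
    """Return a point on the same hypersphere as w0, at (chordal) distance eps from it."""
    D = w0.size
    r = np.linalg.norm(w0)
    w0_hat = w0 / r
    theta = np.arccos(1 - eps ** 2 / (2 * r ** 2))

    # 1. Sample uniformly from the intersection sphere (radius r * sin(theta)),
    #    placed for now in the hyperplane x_D = 0.
    y = np.zeros(D)
    v = rng.standard_normal(D - 1)
    y[:-1] = v * (r * np.sin(theta) / np.linalg.norm(v))

    # 2. Householder-reflect so that the hyperplane's normal aligns with w0_hat.
    n_hat = np.zeros(D)
    n_hat[-1] = 1.0
    c = n_hat + w0_hat
    y = y - 2 * (c @ y) / (c @ c) * c

    # 3. Translate along w0_hat onto the intersection of the two hyperspheres.
    return y + r * np.cos(theta) * w0_hat

w0 = np.random.randn(1000)
w = perturb_on_sphere(w0, eps=0.1 * np.linalg.norm(w0))
print(np.linalg.norm(w) / np.linalg.norm(w0), np.linalg.norm(w - w0) / np.linalg.norm(w0))
```

The two printed numbers should come out as expected: the perturbed vector has the same norm as the baseline, and it sits at the requested relative distance from it.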

Random Walks on Hyperspheres

Neural networks begin their lives on a hypersphere.

As it turns out, they live out the remainder of their lives on hyperspheres as well. Take the norm of the vectorized weights, \mathbf w, and plot it as a function of the number of training steps. Across many different choices of weight initialization, this norm will grow (or shrink, depending on weight decay) remarkably consistently.

In the following figure, I've plotted this behavior for 10 independent initializations across 70 epochs of training a simple MNIST classifier. The same results hold more generally for a wide range of hyperparameters and architectures.

Pasted image 20230124102455.png

The norm during training, divided by the norm upon weight initialization. The parameter ϵ controls how far apart (as a fraction of the initial weight norm) the different weight initializations are.
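The logging itself is nothing fancy. Here's a sketch of what produces curves like the one above (model, loader, optimizer, and loss_fn are placeholders):

```python
import torch

def weight_norm(model: torch.nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().flatten() for p in model.parameters()]).norm()

def train_and_track(model, loader, optimizer, loss_fn, epochs=70):
    """Train normally, logging ||w|| / ||w_init|| once per epoch."""
    init_norm = weight_norm(model)
    history = []
    for _ in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
        history.append((weight_norm(model) / init_norm).item())
    return history
```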

All this was making me think that I disregarded Brownian motion as a model of SGD dynamics a little too quickly. Yes, on lattices/planes, increasing the number of dimensions makes no difference to the overall behavior (beyond decreasing the variance in how displacement grows with time). But maybe that was just the wrong model. Maybe Brownian motion on hyperspheres is considerably different and is actually the thing I should have been looking at.

Asymptotics (t, n \to \infty)

We'd like to calculate the expected displacement (=distance from our starting point) as a function of time. We're interested in both distance on the sphere and distance in the embedding space.

Before we tackle this in full, let's try to understand Brownian motion on hyperspheres in a few simplifying limits.

Unlike the Euclidean case, where having more than one dimension means you'll never return to the origin, living on a finite sphere means we'll visit every point infinitely many times. The stationary distribution is the uniform distribution over the sphere, which we'll denote \mathcal U^{(n)} for S^n.

IMG_671D5890C083-1.jpeg

So our expectations end up being relatively simple integrals over the sphere.

The distance from the pole at the bottom to any other point on the surface depends only on the angle from the line intersecting the origin and the pole to the line intersecting the origin and the other point,

z(\theta) = \sqrt{(r\sin \theta)^2 + (r-r\cos \theta)^2} = 2r\sin \left(\frac{\theta}{2}\right).

This symmetry means we only have to consider a single integral over S^{n-1} "slices" from the bottom to the top of the sphere. Each slice carries as much probability as the corresponding S^{n-1} has surface area, times the measure associated to \theta (= r\,\mathrm d\theta). This surface area varies with the angle from the center, as does the distance of that point from the origin. Put it together and you get some nasty integrals.

IMG_63F07F66DFB1-1.jpeg

From this point on, I believe I made a mistake somewhere (the analysis for n \to \infty is correct).

If we let S_{n-1}(x) denote the surface area of S^{n-1} with radius x, we can fill in the formula from Wikipedia:

S_{n-1}(x) = \frac{2\pi^{\frac{n}{2}}}{\Gamma(\frac{n}{2})} x^{n-1}.

Filling in x(\theta) = r \sin\theta, we get an expression for the surface area of each slice as a function of \theta,

S_{n-1}(x(\theta)) = S_{n-1}(r) \sin^{n-1}(\theta),

so the measure associated to each slice is

\mathrm d\mu(\theta) = S_{n-1}(r)\sin^{n-1}(\theta)\ r\, \mathrm d\theta.

n \to \infty

We can simplify further: as n\to\infty, \sin^{n-1}(\theta) tends towards a "comb function" that is zero everywhere except at \theta = (m + 1/2)\pi for integer m, where it is equal to 1 if m is even and -1 if m is odd.

Within the integration limits (0, \pi), this means all of the probability becomes concentrated at the middle, where \theta=\pi/2.

So in the infinite-dimensional limit, the expected displacement is the distance between a point on the equator and one of the poles, which is \sqrt{2}\, r.
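You can check this numerically without doing the integrals by hand (a sketch; it just averages the chord length against the slice measure):

```python
import numpy as np
from scipy.integrate import quad

# Expected chord length from the pole under the uniform (stationary) distribution:
# average z(theta) = 2 r sin(theta / 2) against the slice measure sin^(n-1)(theta).
r = 1.0
for n in (2, 3, 10, 100, 1000):
    measure = lambda t, n=n: np.sin(t) ** (n - 1)
    chord = lambda t, n=n: 2 * r * np.sin(t / 2) * np.sin(t) ** (n - 1)
    print(n, quad(chord, 0, np.pi)[0] / quad(measure, 0, np.pi)[0])  # -> sqrt(2) * r
```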

General case

An important lesson I've learned here is: "try more keyword variations in Google Scholar and Elicit before going off and solving these nasty integrals yourself because you will probably make a mistake somewhere and someone else has probably already done the thing you're trying to do."

Anyway, here's the thing where someone else has already done the thing I was trying to do. [1] gives us the following form for the expected cosine of the angular displacement with time:

\langle\cos \theta(t)\rangle=\exp \left(-D(n-1) t / R^2\right).

A little trig tells us that the expected displacement in the embedding space is

\langle z(t)\rangle = r \sqrt{2\cdot(1-\langle \cos\theta\rangle)},

which gives us a set of nice curves depending on the diffusion constant D:

output.png

The major difference from Brownian motion on the plane is that the expected displacement asymptotes at the limit we argued for above (\sqrt{2}\, r).
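For reference, the curves come from plugging the closed form into the trig identity (a sketch; r, n, and the diffusion constants here are illustrative):

```python
import numpy as np

r, n = 1.0, 100
t = np.linspace(0, 5, 200)
for D in (0.01, 0.1, 1.0):
    # <z(t)> = r * sqrt(2 * (1 - <cos theta(t)>)) with <cos theta(t)> = exp(-D (n-1) t / r^2)
    z = r * np.sqrt(2 * (1 - np.exp(-D * (n - 1) * t / r ** 2)))
    print(D, z[0], z[-1])  # every curve asymptotes to sqrt(2) * r
```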

When I initially looked at curves like the following (this one's from a [numerical simulation](https://github.com/jqhoogland/experiminis/blob/main/brownian_hyperspheres.ipynb), not the formula), I thought I had struck gold.

rw_sphere.png

I've been looking for a model that predicts square-root-like growth at the very start, followed by a long period of nearly linear growth, followed by asymptotic behavior. Eyeballing this, it looked like my search for the long plateau of linear growth had come to an end.

Of course, this was entirely wishful thinking. Human eyes aren't great discriminators of lines. That "straight line" is actually pretty much parabolic. As I should have known (from the Euclidean nature of manifolds), the initial behavior is incredibly well modeled by a square root.

As you get further away from the nearly flat regime, the concavity only increases. This is not the curve I was looking for.

rw-comparison.png

So my search continues.

Neural networks generalize because of this one weird trick

Produced as part of the SERI ML Alignment Theory Scholars Program - Winter 2022 Cohort

A big thank you to all of the people who gave me feedback on this post: Edmund Lao, Dan Murfet, Alexander Gietelink Oldenziel, Lucius Bushnaq, Rob Krzyzanowski, Alexandre Variengen, Jiri Hoogland, and Russell Goyder.

Statistical learning theory is lying to you: "overparametrized" models actually aren't overparametrized, and generalization is not just a question of broad basins.

slt-generalization-std.png

The standard explanation thrown around here for why neural networks generalize well is that gradient descent settles in flat basins of the loss function. On the left, in a sharp minimum, the updates bounce the model around. Performance varies considerably with new examples. On the right, in a flat minimum, the updates settle to zero. Performance is stabler under perturbations.

To first order, that's because loss basins actually aren't basins but valleys, and at the base of these valleys lie "rivers" of constant, minimum loss. The higher the dimension of these minimum sets, the lower the effective dimensionality of your model.1 Generalization is a balance between expressivity (more effective parameters) and simplicity (fewer effective parameters).

slt-generalization-better.png

Symmetries lower the effective dimensionality of your model. In this example, a line of degenerate points effectively restricts the two-dimensional loss surface to one dimension.

In particular, it is the singularities of these minimum-loss sets — points at which the tangent is ill-defined — that determine generalization performance. The remarkable claim of singular learning theory (the subject of this post), is that "knowledge … to be discovered corresponds to singularities in general" [1]. Complex singularities make for simpler functions that generalize further.

singularities-good.png

The central claim of singular learning theory is that the singularities of the set of minima of the loss function determine learning behavior and generalization. Models close to more complex singularities generalize further.

Mechanistically, these minimum-loss sets result from the internal symmetries of NNs2: continuous variations of a given network's weights that implement the same calculation. Many of these symmetries are "generic" in that they are predetermined by the architecture and are always present. The more interesting symmetries are non-generic symmetries, which the model can form or break during training.

In terms of these non-generic symmetries, the power of NNs is that they can vary their effective dimensionality. Generality comes from a kind of internal model selection in which the model finds more complex singularities that use fewer effective parameters that favor simpler functions that generalize further.

\small\text{Complex Singularities} \iff \text{Fewer Parameters} \iff \text{Simpler Functions} \iff \text{Better Generalization}

At the risk of being elegance-sniped, SLT seems like a promising route to develop a better understanding of generalization and the dynamics of training. If we're lucky, SLT may even enable us to construct a grand unified theory of scaling.

A lot still needs to be done (in terms of actual calculations, the theorists are still chewing on one-layer tanh models), but, from an initial survey, singular learning theory feels meatier than other explanations of generalization. It's more than just meatiness; there's a sense in which singular learning theory is a non-negotiable prerequisite for any theory of deep learning. Let's dig in.

Back to the Bayes-ics

Singular learning theory begins with four things:

  • The "truth", q(x)q(x), which is some distribution that is generating our samples;
  • A model, p(xw)p(x|w), parametrized by weights wWRdw \in \mathcal W \subset \mathbb R^d, where W\mathcal W is compact;
  • A prior over weights, φ(w)\varphi(w);
  • And a dataset of samples Dn={X1,,Xn}D_{n}= \{X_{1}, \dots, X_{n}\}, where each random variable XiX_{i} is i.i.d. according to q(x)q(x).

Here, I'm following the original formulation and notation of Watanabe [1]. Do note that most of this presentation transfers straightforwardly from the context of density estimation (modeling q(x)) to other problems like regression and classification (modeling q(y|x)) [2]. I am also deeply indebted to Carroll's MSc thesis [2] and the wonderful seminars and notes at metauni [3].

The low-level aim of "learning" is to find the optimal weights, ww, for the given dataset. As good Bayesians, this has a very specific and constrained meaning:

p(w|D_n) = \frac{p(D_n|w)\ \varphi(w)}{p(D_n)}.

The higher-level aim of "learning" is to find the optimal model, p(x|w), for the given dataset. Rather than try to find the weights that maximize the likelihood or even the posterior, the true aim of a Bayesian is to find the model that maximizes the model evidence,

p(D_n) = \int_\mathcal W p(D_n|w)\ \varphi(w)\ \mathrm dw.

The fact that the Bayesian paradigm can integrate out its weights to make statements over entire model classes is one of its main strengths. The fact that this integral is almost always intractable is one of its main weaknesses. So the Bayesians make a concession to the frequentists with a much more tractable Laplace approximation: we find a choice of weights, w_0, that maximizes the likelihood and then approximate the distribution as Gaussian in the vicinity of that point.

Pasted image 20230115201321.png

The Laplace Approximation is just a probability theorist's (second-order) Taylor expansion.

This is justified on the grounds that as the dataset grows (n\to\infty), thanks to the central limit theorem, the distribution becomes asymptotically normal (cf. physicists and their "every potential is a harmonic oscillator if you look closely enough / keep on lowering the temperature").

From this approximation, a bit more math leads us to the following asymptotic form for the negative log evidence (in the limit n\to \infty):

-\log p(D_n) \approx \underset{\text{accuracy}}{\underbrace{-\log p(D_n|w_0)}} + \underset{\text{simplicity}}{\underbrace{\frac{d}{2}\log n}},

where d is the dimensionality of parameter space.

This formula is known as the Bayesian Information Criterion (BIC), and it (like the related Akaike information criterion) formalizes Occam's razor in the language of Bayesian statistics. We can end up with models that perform worse as long as they compensate by being simpler. (For the algorithmic-complexity-inclined, the BIC has an alternate interpretation as a device for minimizing the description length in an optimal coding context.)
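As a toy illustration of the trade-off (all numbers here are made up): a model with extra parameters has to buy enough additional log-likelihood to pay the d/2 log n penalty.

```python
import numpy as np

n = 10_000
log_lik_small, d_small = -14_250.0, 3  # hypothetical fit of a 3-parameter model
log_lik_big, d_big = -14_240.0, 10     # hypothetical fit of a 10-parameter model

bic = lambda log_lik, d: -log_lik + d / 2 * np.log(n)
print(bic(log_lik_small, d_small), bic(log_lik_big, d_big))  # lower is better
```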

Unfortunately, the BIC is wrong. Or at least the BIC doesn't apply for any of the models we actually care to study. Fortunately, singular learning theory can compute the correct asymptotic form and reveal its much broader implications.

Statistical learning theory is built on a lie

The key insight of Watanabe is that when the parameter-function map,

\mathcal W \ni w \to p(\cdot|w)

is not one-to-one, things get weird. That is, when different choices of weights can implement the same functions, the tooling of conventional statistical learning theory breaks down. We call such models "non-identifiable".

non-identifiable.png

When the parameter-function map is not one-to-one, the right object of study is not parameter space but function/distribution space.

Take the example of the Laplace approximation. If there's a local continuous symmetry in weight space, i.e., some direction you can walk that doesn't affect the probability density, then your density isn't locally Gaussian.

laplace-approx.png

The Laplace approximation breaks down when there is a direction of perfect flatness.

Even if the symmetries are non-continuous, the model will not in general be asymptotically normal. In other words, the standard central limit theorem does not hold.

The same problem arises if you're looking at loss landscapes in standard presentations of machine learning. Here, you'll find attempts to measure basin volume by fitting a paraboloid to the Hessian of the loss landscape at the final trained weights. It's the same trick, and it runs into the same problem.

This isn't the kind of thing you can just solve by adding a small \epsilon to the Hessian and calling it a day. There are ways to recover "volumes", but they require care. So, as a practical takeaway, if you ever find yourself adding \epsilon to make your Hessians invertible, recognize that those zero directions are important to understanding what's really going on in the network. Offer those eigenvalues the respect they deserve.
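In practice, that means something like this (a sketch; the tolerance is arbitrary): look at the spectrum and count the flat directions instead of fudging them away.

```python
import numpy as np

def count_flat_directions(hessian: np.ndarray, tol: float = 1e-8) -> int:
    """Count (near-)zero eigenvalues of a symmetric Hessian instead of adding eps."""
    eigvals = np.linalg.eigvalsh(hessian)
    return int(np.sum(np.abs(eigvals) < tol))

# Toy example: a loss that ignores one direction entirely, L(w) = w_0^2.
H = np.array([[2.0, 0.0],
              [0.0, 0.0]])
print(count_flat_directions(H))  # 1 flat direction, i.e. one fewer effective parameter
```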

paraboloid-bad.png

Adding epsilon to fudge your paraboloids is a dirty, insidious practice.

The consequence of these zeros (and, yes, they really exist in NNs) is that they reduce the effective dimensionality of your model. A step in these directions doesn't change the actual model being implemented, so you have fewer parameters available to "do things" with.

So the basic problem is this: almost all of the models we actually care about (not just neural networks, but Bayesian networks, HMMs, mixture models, Boltzmann machines, etc.) are loaded with symmetries, and this means we can't apply the conventional tooling of statistical learning theory.

Learning is physics with likelihoods

Let's rewrite our beloved Bayes' update as follows,

p(w|D_n) = \frac{1}{Z_n} \varphi(w)\ e^{-n\beta L_n(w)},

where L_n(w) is the negative log likelihood,

L_n(w) := -\frac{1}{n}\log p(D_n|w) = -\frac{1}{n}\sum\limits_{i=1}^n \log p(x_i|w),

and Z_n is the model evidence,

Z_n := p(D_n) = \int_\mathcal W \varphi(w)\ e^{-n\beta L_n(w)}\ \mathrm dw.

Notice that we've also snuck in an inverse "temperature", \beta > 0, so we're now in the tempered Bayes paradigm [4].

The immediate aim of this change is to emphasize the link with physics, where Z_n is the preferred notation (and "partition function" the preferred name). The information-theoretic analogue of the partition function is the free energy,

F_n := -\log Z_n,

which will be the central object of our study.

Under the definition of a Hamiltonian (or "energy function"),

H_n(w) := nL_n(w) - \frac{1}{\beta}\log\varphi(w),

the translation is complete: statistical learning theory is just mathematical physics where the Hamiltonian is the random process given by the log likelihood ratio function. Just as the geometry of the energy landscape determines the behavior of the physical systems we study, the geometry of the log likelihood ends up determining the behavior of the learning systems we study.

In terms of this physical interpretation, the a posteriori distribution is the equilibrium state corresponding to this empirical Hamiltonian. The importance of the free energy is that it is the minimum of the free energy (not of the Hamiltonian) that determines the equilibrium.

Our next step will be to normalize these quantities of interest to make them easier to work with. For the negative log likelihood, this means subtracting its minimum value.3

But that just gives us the KL divergence,

\begin{align} K_n(w) &= L^0_n(w)\\ &:= L_n(w) - S_n\\ &=\frac{1}{n} \sum\limits_{i=1}^n \log\frac{q(X_i)}{p(X_i|w)}, \end{align}

where S_n is the empirical entropy,

S_n := -\frac{1}{n} \sum\limits_{i=1}^n \log q(X_i),

a term that is independent of w.

mle-is-klm.png

The empirical Kullback-Leibler divergence is just a rescaled and shifted version of the negative log likelihood. Maximum likelihood estimation is equivalent to minimizing the empirical KL divergence.

Similarly, we normalize the partition function to get

Z^0_n := \frac{Z_n}{\prod_{i=1}^n q(X_i)^\beta},

and the free energy to get

F^0_n := -\log Z_n^0.

This lets us rewrite the posterior as

p(w|D_n)=\frac{1}{Z^0_n}\ \varphi(w)\ e^{-n\beta K_n(w)}.

The more important aim of this conversion is that the minima of the term in the exponent, K(w), are now equal to 0. If we manage to find a way to express K(w) as a polynomial, this lets us pull in the powerful machinery of algebraic geometry, which studies the zeros of polynomials. We've turned our problem of probability theory and statistics into a problem of algebra and geometry.

Why "singular"?

Singular learning theory is "singular" because the "singularities" (where the tangent is ill-defined) of the set of your loss function's minima,

\mathcal W_0 := \{w_0 \in \mathcal W|K(w_0)=0\},

determine the asymptotic form of the free energy. Mathematically, \mathcal W_0 is an algebraic variety, which is just a manifold with optional singularities, where it does not have to be locally Euclidean.

Pasted image 20230113172055.png

Example of the curve y^2=x^2+x^3 (equivalently, the algebraic variety of the polynomial f(x, y) = x^2 + x^3 - y^2). There's a singularity at the origin. [Source]

By default, it's difficult to study these varieties close to their singularities. In order to do so anyway, we need to "resolve the singularities." We construct another well-behaved geometric object whose "shadow" is the original object in a way that this new system keeps all the essential features of the original.

It'll help to take a look at the following figure. The idea behind resolution of singularities is to create a new manifold \mathcal U and a map g:\mathcal U \to \mathcal W, such that K(g(u)) is a polynomial in the local coordinates of \mathcal U. We "disentangle" the singularities so that in our new coordinates they cross "normally".

resolution-of-singularities.png

Based on Figure 2.5 of [1]. The lines represent the points that are in \mathcal W_0. The colors are just there to help you keep track of the points.

Because this "blow up" creates a new object, we have to be careful that the quantities we end up measuring don't change with the mapping — we want to find the birational invariants.

We are interested in one birational invariant in particular: the real log canonical threshold (RLCT). Roughly, this measures how "bad" a singularity is. More precisely, it measures the "effective dimensionality" near the singularity.

After fixing the central limit theorem to work in singular models, Watanabe goes on to derive the asymptotic form of the free energy as n \to \infty,

F_n = n\beta S_n + \lambda \log n - (m-1) \log\log n + F^{R}(\xi) + o_p(1),

where \lambda is the RLCT, m is its associated "multiplicity", F^{R}(\xi) is a random variable, and o_p(1) is a random variable that converges (in probability) to zero.

The important observation here is that the global behavior of your model is dominated by the local behavior of its "worst" singularities.

For regular (=non-singular) models, the RLCT is d/2, and with the right choice of inverse temperature, the formula above simplifies to

F_n \approx nS_n +\frac{d}{2}\log n \quad\text{(for regular models)},

which is just the BIC, as expected.

The free energy formula generalizes the BIC from classical learning theory to singular learning theory, which strictly includes regular learning theory as a special case. We see that singularities act as a kind of implicit regularization that penalizes models with higher effective dimensionality.

Phase transitions are singularity manipulations

Minimizing the free energy is maximizing the model evidence, which, as we saw, is the preferred Bayesian way of doing model selection. Other paradigms may disagree4, but at least among us this makes minimizing the free energy the central aim of statistical learning.

As in statistical learning, so in physics.

In physical systems, we distinguish microstates, such as the particular position and speed of every particle in a gas, from macrostates, such as the values of the volume and pressure. The fact that the mapping from microstates to macrostates is not one-to-one is the starting point for statistical physics: uniform distributions over microstates lead to much more interesting distributions over macrostates.

Often, we're interested in how continuously varying our levers (like temperature or the positions of the walls containing our gas) leads to discontinuous changes in the macroscopic parameters. We call these changes phase transitions.

The free energy is the central object of study because its derivatives generate the quantities we care about (like entropy, heat capacity, and pressure). So a phase transition means a discontinuity in one of the free energy's derivatives.

So too, in the setting of Bayesian inference, the free energy generates the quantities we care about, which are now quantities like the expected generalization loss,

G_n = \mathbb E_{X_{n+1}} [F_{n+1}] - F_n.

Except for the fact that the number of samples, n, is discrete, this is just a derivative.5

So too, in learning, we're interested in how continuously changing either the model or the truth leads to discrete changes in the functions we implement and, thereby, to discontinuities in the free energy and its derivatives.

One way to subject this question to investigation is to study how our models change when we restrict them to some subset of parameter space, \mathcal W^{(i)} \subset \mathcal W. What happens as we vary this subset?

Recall that the free energy is defined as the negative log of the partition function. When we restrict ourselves to \mathcal W^{(i)}, we derive a restricted free energy,

\begin{align} F_n(\mathcal W^{(i)}) &:=-\log Z_n(\mathcal W^{(i)})\\ &=-\log \int_{\mathcal W^{(i)} \subset \mathcal W}\varphi(w)\ e^{-n\beta L_n(w)}\ \mathrm dw\\ &= n\beta S_n(\mathcal W^{(i)}) + \lambda^{(i)} \log n - (m^{(i)}-1) \log\log n + F^{R}(\xi) + o_p(1), \end{align}

which has a completely analogous asymptotic form (after swapping out the integrals over all of weight space with integrals over just this subset). The important difference is that the RLCT in this equation is the RLCT associated to the largest singularity in \mathcal W^{(i)} rather than the largest singularity in \mathcal W.

What we see, then, is that phase transitions during learning correspond to discrete changes in the geometry of the "local" (=restricted) loss landscape. The expected behavior for models in these sets is determined by the largest nearby singularities.

phase-transitions 1.png

In a Bayesian learning process, the singularity becomes progressively simpler with more data. In general, learning processes involve trading off a more accurate fit against "regularizing" singularities. Based on Figure 7.6 in [1].

In this light, the link with physics is not just the typical arrogance of physicists asserting themselves on other people's disciplines. The link goes much deeper.

Physicists have known for decades that the macroscopic behavior of the systems we care about is the consequence of critical points in the energy landscape: global behavior is dominated by the local behavior of a small set of singularities. This is true everywhere from statistical physics and condensed matter theory to string theory. Singular learning theory tells us that learning machines are no different: the geometry of singularities is fundamental to the dynamics of learning and generalization.

Neural networks are freaks of symmetries

The trick behind why neural networks generalize so well is something like their ability to exploit symmetry. Many models take advantage of the parameter-function map not being one-to-one. Neural networks take this to the next level.

There are discrete permutation symmetries, where you can flip two columns in one layer as long as you flip the two corresponding rows in the next layer, e.g.,

\begin{pmatrix} \color{red} a & \color{blue} b & c\\ \color{red} d & \color{blue}e & f \\ \color{red}g & \color{blue} h & i \end{pmatrix} \cdot \begin{pmatrix} \color{red}j & \color{red}k & \color{red}l\\ \color{blue}m & \color{blue}n & \color{blue}o \\ p & q & r \end{pmatrix} =\begin{pmatrix} \color{blue} b & \color{red} a & c\\ \color{blue} e & \color{red}d & f \\ \color{blue}h & \color{red} g & i \end{pmatrix} \cdot \begin{pmatrix} \color{blue}m & \color{blue}n & \color{blue}o \\ \color{red}j & \color{red}k & \color{red}l\\ p & q & r \end{pmatrix}.

There are scaling symmetries associated to ReLU activations,

\text{ReLU}(x) =\frac{1}{\alpha}\text{ReLU}(\alpha x),\quad \alpha > 0,

and associated to layer norm,

\text{LayerNorm}(\alpha x) = \text{LayerNorm}(x),\quad \alpha > 0.

(Note: these are often broken by the presence of regularization.)

And there's a GL_n symmetry associated to the residual stream (you can multiply the embedding matrix by any invertible matrix as long as you apply the inverse of that matrix before the attention blocks, the MLP layers, and the unembedding layer, and apply the matrix after each attention block and MLP layer).
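Here's the ReLU rescaling symmetry in action for a toy two-layer network (a sketch; the shapes and weights are arbitrary placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.standard_normal((64, 16)), rng.standard_normal((4, 64))
x = rng.standard_normal(16)
relu = lambda z: np.maximum(z, 0)

alpha = 3.7
f_original = W2 @ relu(W1 @ x)
f_rescaled = (W2 / alpha) @ relu((alpha * W1) @ x)  # scale in by alpha, out by 1/alpha
print(np.allclose(f_original, f_rescaled))          # True: same function, different weights
```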

But these symmetries aren't actually all that interesting. That's because they're generic. They're always present for any choice of w. The more interesting symmetries are non-generic symmetries that depend on w.

It's the changes in these symmetries that correspond to phase transitions in the posterior; this is the mechanism by which neural networks are able to change their effective dimensionality.

These non-generic symmetries include things like a degenerate node symmetry, which is the well-known case in which a weight is equal to 0 and performs no work, and a weight annihilation symmetry in which multiple weights are non-zero but combine to have an effective weight of zero.

The consequence is that even if our optimizers are not performing explicit Bayesian inference, these non-generic symmetries allow the optimizers to perform a kind of internal model selection. There's a trade-off between lower effective dimensionality and higher accuracy that is subject to the same kinds of phase transitions as discussed in the previous section.

The dynamics may not be exactly the same, but it is still the singularities and geometric invariants of the loss landscape that determine the dynamics.

Discussion and limitations

All of the preceding discussion holds in general for any model where the parameter-function mapping is not one-to-one. When this is the case, singular learning theory is less a series of interesting and debate-worthy conjectures than a necessary frame.

The more relevant question is whether this theory actually tells us anything useful in practice. Quantities like the RLCT are exceedingly difficult to calculate for realistic systems, so can we actually put this theory to use?

I'd say the answer is a tentative yes. Results so far suggest that the predictions of SLT hold up to experimental scrutiny — the predicted phase transitions are actually observable for small toy models.

That's not to say there aren't limitations. I'll list a few from here and a few of my own.

Before we get to my real objections, here are a few objections I think aren't actually good objections:

  • But we care about function approximation. This whole discussion is couched in a very probabilistic context. In practice, we're working with loss functions and are approximating functions, not densities. I don't think this is much of a problem as it's usually possible to recover your Bayesian footing in deterministic function approximation. Even when this isn't the case, the general claim — that the geometry of singularities determines dynamics — seems pretty robust.
  • But we don't even train to completion! (/We're not actually reaching the minimum loss solutions). I expect most of the results to hold for any level set of the loss landscape — we'll just be interested in the dominant singularities of the level sets we end up in (even if they don't perfectly minimize the loss).
  • But calculating (and even approximating) the RLCT is pretty much intractable. In any case, knowing of something's theoretical existence can often help us out on what may initially seem like unrelated turf. A more optimistic counter would be something like "maybe we can compute this for simple one-layer neural networks, and then find a straightforward iterative scheme to extend it to deeper layers." And that really doesn't seem all too unreasonable — when I see all the stuff physicists can squeeze out of nature, I'm optimistic about what learning theorists can squeeze out of neural networks.
  • But how do you adapt the results from \tanh to realistic activations like swishes? In the same way that many of the universal approximation theorems don't depend on the particulars of your activation function, I don't expect this to be a major objection to the theory.
  • But ReLU networks are not analytic. Idk man, seems unimportant.
  • But what do asymptotic limits in n actually tell us about the finite case? I guess it's my background in statistical physics, but I'd say that a few trillion tokens is a heck of a lot closer to infinity than it is to zero. In all seriousness, physics has a long history of success with finite-size scaling and perturbative expansions around well-behaved limits, and I expect these to transfer.
  • But isn't this all just a fancy way of saying it was broad basins this entire time? Yeah, so I owe you an apology for all the Hessian-shaming and introduction-clickbaiting. In practice, I do expect small eigenvalues to be a useful proxy to how well specific models can generalize — less than zeros, but not nothing. Overall, the question that SLT answers seems to be a different question: it's about why we should expect models on average (and up to higher order moments) to generalize.

My real objections are as follows:

  • But these predictions of "generalization error" are actually a contrived kind of theoretical device that isn't what we mean by "generalization error" in the typical ML setting. Pretty valid, but I'm optimistic that we can find the quantities we actually care about from the ones we can calculate right now.
  • But what does Bayesian inference actually have to do with SGD and its variants? This complaint seems rather important especially since I'm not sold on the whole NNs-are-doing-Bayesian-inference thing. I think it's conceivable that we can find a way to relate any process that decreases free energy to the predictions here, but this does remain my overall biggest source of doubt.
  • But the true distribution is not realizable. For the above presentation, we assumed there is some choice of parameters w_0 such that p(x|w_0) is equal to q(x) almost everywhere (this is "realizability" or "grain of truth"). In real-world systems, this is never the case. For renormalizable6 models, extending the results to the non-realizable case turns out to be not too difficult. For non-renormalizable theories, we're in novel territory.

Where Do We Go From Here?

I hope you've enjoyed this taster of singular learning theory and its insights: the sense of learning theory as physics with likelihoods, of learning as the thermodynamics of loss, of generalization as the presence of singularity, and of the deep, universal relation between global behavior and the local geometry of singularities.

The work is far from done, but the possible impact for our understanding of intelligence is profound. 

To close, let me share one of the directions I find most exciting — that of singular learning theory as a path towards predicting the scaling laws we see in deep learning models [5].

There's speculation that we might be able to transfer the machinery of the renormalization group, a set of techniques and ideas developed in physics to deal with critical phenomena and scaling, to understand phase transitions in learning machines, and ultimately to compute the scaling coefficients from first principles.

To borrow Dan Murfet's call to arms:

It is truly remarkable that resolution of singularities, one of the deepest results in algebraic geometry, together with the theory of critical phenomena and the renormalisation group, some of the deepest ideas in physics, are both implicated in the emerging mathematical theory of deep learning. This is perhaps a hint of the fundamental structure of intelligence, both artificial and natural. There is much to be done!

References

[1]: Watanabe 2009

[2]: Carroll 2021

[3]: Metauni 2021-2023 (Super awesome online lecture series hosted in Roblox that you should all check out.)

[4]: Guedj 2019

[5]: Kaplan 2020

Footnotes

  1. The dimensionality of the set of optimal parameters also depends on the true distribution generating your data, but even if the set of optimal parameters is zero-dimensional, the presence of level sets elsewhere can still affect learning and generalization.

  2. And from the underlying true distribution.

  3. To be precise, this rests on the assumption of realizability — that there is some weight w_0 for which p(x|w_0) equals q(x) almost everywhere. In this case, the minimum value of the negative log likelihood is the empirical entropy.

  4. They are, of course, wrong.

  5. So n is really a kind of inverse temperature, like \beta. Increasing the number of samples decreases the effective temperature, which brings us closer to the (degenerate) ground state.

  6. A word with a specific technical sense but that is related to renormalization in statistical physics.

Beyond Bayes

Context: We want to learn an appropriate function f provided samples from a dataset D_n = \{(X, Y)\}^n.

Turns out, you can do better than the naive Bayes update,

P(f|D_n) = \frac{P(D_n| f)\ P(f)}{P(D_n)}.

Tempered Bayes

Introduce an inverse temperature, \beta, to get the tempered Bayes update [1]:

P_\beta(f|D_n) = \frac{P(D_n|f)^{\beta}\ P(f)}{P_{\beta}(D_n)}.

At first glance, this looks unphysical. Surely P(A|B)^{\beta}\ P(B) = P(A, B) only when \beta=1?

If you're one for handwaving, you might just accept that this is just a convenient way to vary between putting more weight on the prior and more weight on the data. In any case, the tempered posterior is proper (integrable to one), as long as the untempered posterior is [2].

If you're feeling more thorough, think about the information. Introducing an inverse temperature simply scales the number of bits contained in the distribution: P(X, Y|f) = \exp\{-\beta I(X, Y|f)\}.
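To see what \beta does in practice, here's a toy tempered update for a Bernoulli model on a parameter grid (a sketch; all numbers are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.random(20) < 0.8          # 20 coin flips with true bias 0.8
k, n = int(data.sum()), data.size

theta = np.linspace(0.01, 0.99, 99)  # grid over the Bernoulli parameter
prior = np.full_like(theta, 1 / theta.size)
log_lik = k * np.log(theta) + (n - k) * np.log(1 - theta)

for beta in (0.1, 1.0, 10.0):
    post = prior * np.exp(beta * log_lik)   # tempered update: likelihood^beta * prior
    post /= post.sum()
    print(beta, theta[np.argmax(post)], round(post.max(), 3))  # sharper as beta grows
```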

TODO: Check out Grünwald's Safe Bayes papers

Generalized Bayes

If you're feeling even bolder, you might replace the likelihood with a general loss term, \ell_{\beta, n}(f), which measures performance on your dataset D_n,

P_\beta(f|D_n) = \frac{\ell_{\beta,n}(f)\ P(f)}{Z_n},

where we write the normalizing constant or partition function as Z_n to emphasize that it isn't really an "evidence" anymore.

The most natural choice for \ell_{\beta,n} is the Gibbs measure:

\ell_{\beta, n}(f) = \exp\left\{-\beta\, r_n(f)\right\},

where r_n is the empirical risk of classical machine learning. You can turn any function into a probability.
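And a sketch of the generalized update with this choice (a toy regression over a finite family of candidate functions; everything here is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal(50)
Y = 2.0 * X + 0.1 * rng.standard_normal(50)

slopes = np.linspace(-4, 4, 81)                                # candidates f(x) = a * x
risk = np.array([np.mean((Y - a * X) ** 2) for a in slopes])   # empirical risk r_n(f)

beta = 5.0
prior = np.full_like(slopes, 1 / slopes.size)
post = prior * np.exp(-beta * risk)   # Gibbs measure: exp(-beta * r_n(f)) * prior
post /= post.sum()                    # Z_n is just a normalizer, not an evidence
print(slopes[np.argmax(post)])        # ≈ 2.0
```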

Jeffreys Updates

TODO