Notes

Generalization

NOTE: Work in Progress

Outline

  • Warm-up: MSE & bias-variance trade-offs.
  • Generalization = Simplicity.
    • Solomonoff induction,
    • Complexity measures
      • BIC/MDL, AIC
      • VC dimension & structural risk minimization
      • Rademacher complexity
  • Early Stopping
    • NTK
  • Problems:
    • Vacuous bounds (Zhang et al. 2017). Always zero training error. Any optimizer (even stupid ones). Crazily convex loss spots. Even with regularization
  • From functions to probabilities
  • Mingard-style “Bayesian inference”

Misc:

  • SGD & Basin broadness
  • Heavy-tailed random matrix theory (implicit self-regularization)
  • Double descent & grokking
  • Recent Anthropic work. It’s non-obvious that (l2) regularization chooses simpler functions in NNs. (Polynomial regression is one thing.)

"Regular" learning theory (RLT) predicts that overparametrized deep neural networks (DNNs) should overfit. In practice, the opposite happens: deeper networks perform better and generalize further. What's going on?

The jury is still out, but different strands of research appear to be converging on a common answer: DNNs don't overfit because they're not actually overparametrized; SGD and NNs' symmetries favor "simpler" models whose effective parameter counts are much lower than the externally observed parameter count. Generalization comes from a kind of internal model selection in which neural networks throw out unnecessary expressivity.

In this post, I'll make that statement more precise. We'll start by reviewing what RLT has to say about generalization and why it fails to explain generalization in DNNs. We'll close with a survey of more recent attempts at the learning theory of DNNs with particular emphasis on "singular" learning theory (SLT).

Intro to Statistical Learning Theory

Before we proceed, we need to establish a few definitions. Statistical learning theory is often sloppy with its notation and assumptions, so we have to start at the basics.

In this post, we'll focus on a (self-)supervised learning task. Given an input xXx \in \mathcal X, we'd like to predict the corresponding output yYy \in \mathcal Y. This relationship is determined by an unknown "true" pdf q(z)=q(x,y)q(z) = q(x, y) over the joint sample space Z=X×Y\mathcal Z = \mathcal X \times \mathcal Y.

If we're feeling ambitious, we might directly approximate the true joint distribution with a generative model p(x,yw)p(x, y|w) that is parametrized by weights wWRDw \in \mathcal W \subseteq \mathbb R^D. If it's tractable, we can marginalize this joint distribution to obtain the marginal and conditional distributions. 1

If we're feeling less ambitious, we might go ahead with a discriminative model of the conditional distribution, p(yx,w)p(y| x, w). (In the case of regression and classification, we're interested in inferring target values and labels from future samples, so we may not care all that much about q(x)q(x) and q(xy)q(x|y).)

And if we're feeling unambitious (as most machine learning theorists are), it may be enough to do direct inference with a discriminant function — i.e., to find a deterministic function that maps inputs to outputs, fw:XYf_w: \mathcal X \to \mathcal Y.2

In the case of regression, this is justified on the (only occasionally reasonable) assumption that the error is i.i.d. sampled from a Gaussian distribution. In other words, the discriminative model is defined in terms of ff as:

p(y|x, w) = \mathcal N(y | f_{w}(x), \epsilon).

Machine learning usually begins with discriminant functions. Not only is this often easier to implement, but it offers some additional flexibility in the kinds of functions we can model because we optimize arbitrary loss functions that are unmoored from Bayesian-grounded likelihoods or posteriors.3

This flexibility comes at a cost: without the theoretical backing, it's more difficult to reason systematically about what's actually going on. The loss functions and regularizers of machine learning often end up feeling ad hoc and unsupported because they are. Or at least they will be until section 3.

As a result, we'll find it most useful to work at the intermediate level of analysis, in terms of discriminative models.

What Is Learning, Anyway?

The strength of the Bayesian framing is that for some prior over weights, ϕ(w)\phi(w), and a dataset d(n)Zn\mathbf d^{(n)} \in \mathcal Z^n of samples {(xi,yi)}i=1n\{(x_i, y_i)\}_{i=1}^n, we can "reverse" the likelihood (assuming each sample is sampled i.i.d. from p(x,y)p(x, y)),

p(\mathbf d^{(n)}|w) = \prod_{i=1}^{n}p(x_{i}, y_{i}|w),

to obtain the a posteriori distribution,

p(w|\mathbf d^{(n)}) = \frac{1}{Z_{n}}p(\mathbf d^{(n)}|w)\phi(w),

where ZnZ_n is the evidence (or partition function),

Z_{n}= \int_{\mathcal W} p(\mathbf d^{(n)}|w)\,\phi(w)\,\mathrm dw.
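
To make the machinery concrete, here's a minimal sketch (mine, not from the text) of this update on a one-parameter toy model: we put a grid over w, score each value by the likelihood of the data, multiply by the prior, and normalize. The linear-Gaussian model and all numbers are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data from a "true" linear relation y = 2x + noise.
n = 20
x = rng.uniform(-1, 1, size=n)
y = 2.0 * x + rng.normal(0, 1.0, size=n)

# Grid over a single weight w; Gaussian prior phi(w).
w_grid = np.linspace(-5, 5, 1001)
log_prior = -0.5 * w_grid**2  # phi(w) ~ N(0, 1), up to a constant

# Log-likelihood: sum_i log N(y_i | w * x_i, 1), up to a constant.
log_lik = np.array([-0.5 * np.sum((y - w * x) ** 2) for w in w_grid])

# Posterior p(w|d) = likelihood * prior / Z_n, computed on the grid.
log_post = log_lik + log_prior
log_post -= log_post.max()   # for numerical stability
post = np.exp(log_post)
post /= post.sum()           # discrete normalization: this plays the role of 1/Z_n

print("posterior mean:", (w_grid * post).sum())
```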

The aim of "learning" is to make our model p(yx,w)p(y|x, w) as "close" as possible to the truth q(yx)q(y|x).

In the probabilistic formulation, the natural choice of "distance" is the Kullback-Leibler divergence (which is not actually a distance):

K(w) := D_\text{KL}\left(q(x, y)\,||\,p(x, y| w)\right) = \left\langle \log \left(\frac{q(y|x)}{p(y|x, w)}\right) \right\rangle_{X, Y},

where the second equality follows from defining p(x, y|w) = p(y|x, w) \cdot q(x).

Formally, the aim of learning is to solve the following:

w^* = \underset{w}{\operatorname{argmin}}\, K(w)

We say q(y|x) is realizable iff K(w^*) = 0, in which case we call w^* the "true" parameters.

At this point, we run into a bit of a problem: we can't actually compute the expectations \langle \cdot\rangle_{X, Y} since we don't know q(x, y) (and even if we did, the integration would likely be intractable). Instead, we have to resort to empirical averages over a dataset.3

Maximum Likelihood Estimation is KL-Divergence Minimization

Example: Regression & Mean-Squared Error

Fortunately, there's a more general and principled way to recover discriminant models from discriminant functions than the assumption of isotropic noise, as long as we're willing to relax Bayes' rule.

First, we introduce a parameter \beta \geq 0 that controls the tradeoff between prior and likelihood to obtain the tempered Bayes update [1],4

p_\beta(w|\mathbf d) \propto p(\mathbf d| w)^{\beta}p(w),

where d\mathbf d is a dataset of samples {(xi,yi)}i=1n\{(x_{i},y_{i})\}_{i=1}^n, and the likelihood is obtained by assuming each sample is distributed i.i.d. according to p(xi,yiw)p(x_{i}, y_{i}|w) s.t. p(dw)=i=1np(xi,yiw)p(\mathbf d|w) = \prod_{i=1}^{n}p(x_{i}, y_{i}|w).

Next, we replace the tempered likelihood with a general loss function.

The real world is noisy, and pure function approximation can't account for this.

To remedy this, it's common to model any departure from our deterministic predictions, y^=fw(x)\hat y = f_w(x), as isotropic noise, e.g., q(yx,w)=N(yfw(x),σ2)q(y|x, w) = \mathcal N(y| f_w(x), \sigma^2).

Physically, we may expect the underlying process to be deterministic (given by a "true" function ff^*) with noise creeping in through some unbiased measurement error ϵN(0,σ2)\epsilon \sim \mathcal N(0, \sigma^2).

For Gaussian noise, the NLL works out to:

L_\mathbf d(w) = \frac{1}{2\sigma^2 N}\sum_{i=1}^N (y_i-\hat y_i)^2 + \text{const}.

Assuming isotropic Gaussian noise in a regression setting, we see that MLE simplifies to minimizing the mean squared error (up to some overall constant and scaling).

To make MLE more Bayesian, you just multiply the likelihood by some prior over the weights to obtain maximum a posteriori estimation (MAP): p(w|\mathbf d) \propto p(\mathbf d|w)\phi(w). If we enforce a Gaussian prior with precision \alpha, the loss gains a regularization term (weight decay), \frac{\alpha}{2}w^T w. [2]
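
As a quick sanity check on both claims (the Gaussian NLL reduces to MSE; a Gaussian prior adds weight decay), here's a small sketch on a one-parameter regression problem. The data, σ, and α are arbitrary; this is my own illustration, not a reference implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma, alpha = 50, 0.5, 1e-2

x = rng.normal(size=n)
y = 3.0 * x + rng.normal(0, sigma, size=n)

def nll(w):
    """Negative log-likelihood under q(y|x,w) = N(y | w*x, sigma^2), dropping constants."""
    return np.sum((y - w * x) ** 2) / (2 * sigma**2 * n)

def map_loss(w):
    """NLL plus the negative log of a Gaussian prior with precision alpha: weight decay."""
    return nll(w) + 0.5 * alpha * w**2

w_grid = np.linspace(-5, 10, 10001)
w_mle = w_grid[np.argmin([nll(w) for w in w_grid])]
w_mse = w_grid[np.argmin([np.mean((y - w * x) ** 2) for w in w_grid])]
w_map = w_grid[np.argmin([map_loss(w) for w in w_grid])]

print(w_mle, w_mse)  # identical: maximizing the likelihood == minimizing the MSE
print(w_map)         # pulled slightly toward zero by the prior
```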

Gibbs Generalization Error

Todo

What Is Generalization, Anyway?

Consider the empirical probability distribution,

q(zd)=1Nzidδ(zzi), q(z|\mathbf d) = \frac{1}{N}\sum_{z_i \in \mathbf d} \delta(z-z_i),

where δ()\delta(\cdot) is a delta function appropriate to the sample space.

Although this will converge to the true distribution as N\to\infty, for any finite N we encounter sampling error. This problem is particularly pronounced for unseen samples: q(z|\mathbf d) may (and almost always does) assign zero probability to samples that have non-zero probability under the true distribution.

The fundamental challenge of generalization is to make predictions about these unseen samples. Let's make this more precise.

Given some loss function \ell: \mathcal Y \times \mathcal Y \to \mathbb R, we're interested in how our model performs on future samples, as measured by the expected risk or generalization error,

\begin{align} I\ :\ \mathcal F &\to \mathbb R \\ f &\mapsto \left\langle \ell(f(x), y)\right\rangle_{X,Y}. \end{align}

However, we can only estimate this performance on the available dataset via the empirical risk,

\begin{align} I_\mathbf d\ :\ \mathcal F &\to \mathbb R \\ f &\mapsto \overline{\ell(f(x), y)}, \end{align}

where \overline{\,\cdot\,} denotes an empirical average over the dataset, \mathbf d.

Our problem is that the true generalization error is unknowable. So, in practice, we split our dataset into a training set, \mathbf s, and a test set, \mathbf e. We learn our parameters w via the training set, then estimate the generalization error on the test set. That's where we encounter our first major confusion.3

You see, besides the generalization error, there's the distinct notion of the generalization gap,

\epsilon_g = I[f_w|\mathbf S = \mathbf s] - I[f_w|\mathbf E = \mathbf e].

Whereas generalization error is about absolute performance on the true distribution, the generalization gap is about relative performance between the training and test sets. Generalization error combines both "performance" and "transferability" into one number, while the generalization gap is more independent of "performance" — past a certain threshold, any amount of test-set performance is compatible with both a low and high generalization gap.

For the remainder of this post, I'll focus on the generalization gap. It's not absolute performance we're interested in: universal approximation theorems tell us that we should expect excellent training performance for neural networks. The less obvious claim is why this performance should transfer so well to novel samples.
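
For concreteness, here's a minimal sketch of how the gap gets estimated in practice: fit on a training split, evaluate the empirical risk on both splits, and compare. The toy sinusoidal dataset and polynomial models are stand-ins, not anything from the literature.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: y = sin(2*pi*x) + noise.
x = rng.uniform(0, 1, size=200)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=len(x))

# Split into a training set (s) and a test set (e).
idx = rng.permutation(len(x))
train, test = idx[:40], idx[40:]

def empirical_risk(w, xs, ys):
    """Mean squared error of a polynomial with coefficients w."""
    return np.mean((np.polyval(w, xs) - ys) ** 2)

for degree in [1, 3, 9]:
    w = np.polyfit(x[train], y[train], degree)
    I_train = empirical_risk(w, x[train], y[train])
    I_test = empirical_risk(w, x[test], y[test])
    # The (magnitude of the) generalization gap: how much worse we do off the training set.
    print(f"degree {degree}: train {I_train:.3f}  test {I_test:.3f}  gap {I_test - I_train:.3f}")
```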


Yes, more data means better generalization, but that's not what we're talking about.

In the limit N\to \infty, the empirical risk converges to the expected risk, and both the generalization error and the generalization gap fade away: \epsilon_g \to 0,\ \epsilon_g^2 \to 0 as N \to \infty. So, obviously, the larger datasets that have accompanied the deep learning boom explain some of the improvement in generalization.

PAC Learning

Introduced by Valiant (1984) [3], Probably Approximately Correct (PAC) learning establishes upper bounds for the risk I[f]. In its simplest form, it states that, with high probability 1-\epsilon ("probably"), a predictor, f, will have its risk bounded by some \delta ("approximately correct"):

\mathbb{P}[I[f] \leq \delta] \geq 1-\epsilon.

Originally, PAC learning included the additional requirement that the learning algorithm run in time polynomial in 1/\epsilon, 1/\delta, and the size of the problem, but the term has since been relaxed to refer to any bound that holds with high probability.
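
To give a feel for the kind of statement PAC learning buys you, here's a sketch of the textbook Hoeffding-plus-union-bound result for a finite hypothesis class and a loss bounded in [0, 1] (my gloss, not something from the post): with probability at least 1 - \epsilon over the draw of the dataset, every f in the class satisfies I[f] \leq I_\mathbf d[f] + \delta(n, |\mathcal F|, \epsilon). For a class as expressive as a DNN, the implied |\mathcal F| makes the bound effectively vacuous.

```python
import numpy as np

def pac_bound(n, num_hypotheses, eps):
    """One-sided Hoeffding + union bound for a finite class and loss in [0, 1]:
    with probability >= 1 - eps, I[f] <= I_d[f] + delta for every f in the class."""
    return np.sqrt((np.log(num_hypotheses) + np.log(1 / eps)) / (2 * n))

for n in [10**2, 10**4, 10**6]:
    print(n, pac_bound(n, num_hypotheses=2**20, eps=0.05))
```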

Still, it's not a full explanation as typical networks can easily memorize the entire training set (even under random labelings [4]).


MSE & Bias-variance tradeoff

In the case of isotropic noise (where our loss function is the MSE), the generalization error of a model is:

\text{Generalization error} = \left\langle(y - \hat y)^2 \right\rangle_{X,Y} = \text{Bias}^2 + \text{Variance} + \text{Irreducible error}.

In other words, it is the squared (l2) distance between our predictions and the truth, averaged over the true variables X \times Y. The bias-variance decomposition splits the generalization error into a bias term,

\text{Bias} = \left\langle \langle f_w(x)\rangle_{\mathbf D}-y\right\rangle_{X,Y},

which measures the average difference between predictions and true values, a variance term,

\text{Variance} = \left\langle \left(\langle f_w(x)\rangle_{\mathbf D} - f_w(x)\right)^2\right\rangle_{X, \mathbf D},

which measures the spread of a model's predictions at a given point in input space across models trained on different draws of the dataset (here \langle\cdot\rangle_{\mathbf D} is an average over datasets and the resulting fitted weights), and an irreducible error due to the inherent noise in the data. (For Gaussian noise, the irreducible error is \sigma^2.)

In order to achieve good generalization performance, a model must have low bias and low variance. This means that the model must be complex enough to capture the underlying patterns in the data, but not so complex that it overfits the data and becomes sensitive to noise. In practice, these two forces are at odds, hence "tradeoff."
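
Here's a minimal simulation of the decomposition (with an invented sinusoidal ground truth and polynomial models, my own setup): we refit the model on many freshly drawn training sets and estimate bias² and variance pointwise.

```python
import numpy as np

rng = np.random.default_rng(0)
f_true = lambda x: np.sin(2 * np.pi * x)
sigma = 0.2                      # irreducible noise level
x_eval = np.linspace(0, 1, 200)  # points at which we evaluate the decomposition

def fit_predict(degree):
    """Fit a polynomial to one freshly sampled training set, predict on x_eval."""
    x = rng.uniform(0, 1, size=30)
    y = f_true(x) + rng.normal(0, sigma, size=len(x))
    return np.polyval(np.polyfit(x, y, degree), x_eval)

for degree in [1, 3, 9]:
    preds = np.stack([fit_predict(degree) for _ in range(500)])  # 500 resampled datasets
    mean_pred = preds.mean(axis=0)
    bias2 = np.mean((mean_pred - f_true(x_eval)) ** 2)
    variance = np.mean(preds.var(axis=0))
    print(f"degree {degree}: bias^2 {bias2:.4f}  variance {variance:.4f}  "
          f"bias^2 + var + sigma^2 {bias2 + variance + sigma**2:.4f}")
```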


Generalization is about simplicity

A common assumption across the entire statistical learning theory literature is that simpler functions generalize better. It's worth spending a second to understand why we should expect this assumption to hold.

Occam's Razor

"Entities must not be multiplied beyond necessity".

Occam's razor is the principle of using the simplest explanation that fits the available data. If you're reading this, you probably already take it for granted.

Solomonoff Induction

Algorithmic complexity theory lets us define a probability distribution over computable numbers,

p(x)= \sum_{p\,:\,U(p)=x}2^{-\ell(p)} \approx 2^{-K(x)}, where \ell(p) is the length of program p on some Universal Turing Machine U, and K(\cdot) is the (prefix) Kolmogorov complexity. The probability of any given number is the probability that a uniform random input tape on that Universal Turing Machine outputs that number. Unfortunately, the halting problem makes this just a tad uncomputable. Still, in principle, this allows a hypercomputer to reason from the ground up what the next entry in a given sequence of numbers should be.

Bayesian/Akaike Information Criterion

One of the main strengths of the Bayesian frame is that it lets us enforce a prior \varphi(w) over the weights, which you can integrate out to derive a parameter-free model:

p(y|x) = \int_w p(y|x, w)\varphi(w)\ \text{d}w.

One of the main weaknesses is that this integral is almost always intractable. So Bayesians make a concession to the frequentists with a much more tractable Laplace approximation (i.e., you approximate your model as quadratic/Gaussian in the vicinity of the maximum likelihood estimator (MLE), w^{(0)}):

K(w) \approx \frac{1}{2}(w-w^{(0)})^T I(w^{(0)}) (w-w^{(0)}),

where I(w)I(w) is the Fisher information matrix:

I_{j,k}(w)=\left\langle\left(\frac{\partial}{\partial w_j} \log p(y|x, w)\right)\left(\frac{\partial}{\partial w_k} \log p(y |x, w)\right) \right\rangle_{X,Y}.

From this approximation, a bit more math gives us the Bayesian information criterion (BIC):

\text{BIC} = L_n(w_0) + \frac{D}{2}\log n.

The BIC (like the related Akaike information criterion) is a criterion for model selection that penalizes complexity. Given two models, the one with the lower BIC tends to overfit less (/"generalize better").

That is, simpler models (with fewer parameters) are more likely in approximate Bayesian inference. We'll see in the section on singular learning theory that the BIC is unsuitable for deep neural networks, but a generalized version, the Widely Applicable Bayesian Information Criterion (WBIC) will pick up the slack.
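
A minimal sketch of BIC-style model selection, reading L_n(w_0) above as the total negative log-likelihood at the (approximate) MLE; the polynomial-regression setup and noise level are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 100, 0.3
x = rng.uniform(-1, 1, size=n)
y = 1.0 - 2.0 * x + 0.5 * x**2 + rng.normal(0, sigma, size=n)  # the true model is quadratic

def bic(degree):
    """BIC = L_n(w_0) + (D/2) log n, with L_n the total Gaussian NLL at the MLE."""
    w = np.polyfit(x, y, degree)
    resid = y - np.polyval(w, x)
    sigma2_hat = np.mean(resid**2)
    nll = 0.5 * n * (np.log(2 * np.pi * sigma2_hat) + 1)  # profiled over the noise scale
    D = degree + 2                                          # polynomial coefficients + noise scale
    return nll + 0.5 * D * np.log(n)

for degree in range(6):
    print(degree, round(bic(degree), 2))
# The minimum should land at (or near) the true degree, 2: extra parameters get penalized.
```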

Minimum Description Length

The BIC is formally equivalent to MDL. TODO

Maximum Entropy Modeling

TODO

Classic Learning Theory

As we've seen in the previous section, the question of how to generalize reduces to the question of how to find simple models that match the data.

In classical learning theory, the answer to this is easy: use fewer parameters, include a regularization term, and enforce early stopping. Unfortunately, none of these straightforwardly help us out with deep neural networks.

For one, double descent tells us that the relation between parameter count and generalization is non-monotonic: past a certain number of parameters, generalization error will start decreasing. Moreover, a large number of the complexity measures we've seen are extensive (they scale with model size). That suggests they're not suited to the task.

As for regularization, it gets more complicated. That regularization on polynomial regression leads to simpler functions is straightforward. That regularization on deep neural networks leads to simpler models is less obvious. Sure, a sparse l1 regularizer that pushes weights to zero seems like it would select for simpler models. But a non-sparse l2 regularizer? Linking small parameters to simple functions will require more work. A deeper problem is that DNNs can memorize randomly labeled data even with regularizers [5]. Why, then, do DNNs behave so differently on correctly labeled (/well-structured) data?
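
A compact version of that memorization observation might look like the sketch below (my own toy setup, not Zhang et al.'s): a small MLP fit to completely random labels, with weight decay switched on, still typically reaches (near-)perfect training accuracy.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Random inputs with *random* labels: there is no signal to learn, only memorize.
n, d, k = 512, 32, 10
X = torch.randn(n, d)
y = torch.randint(0, k, (n,))

model = nn.Sequential(nn.Linear(d, 512), nn.ReLU(), nn.Linear(512, k))
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)  # note: regularized
loss_fn = nn.CrossEntropyLoss()

for step in range(3000):
    opt.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    opt.step()

with torch.no_grad():
    acc = (model(X).argmax(dim=1) == y).float().mean()
print(f"train accuracy on random labels: {acc:.2%}")  # typically close to 100%
```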

Other forms of regularization like dropout are more understandable: they explicitly select for redundancy, which collapses the effective parameter count.

Finally, early stopping doesn't seem to work with the observation of grokking: in certain models, training and test loss may plateau (at a poor generalization gap) for many epochs before the test loss suddenly drops sharply (towards a much better generalization gap).

SGD Favors Flat Minima

  • Pβ(f)P_\beta(f) is the probability that MM expresses ff on DD upon a randomly sampled parametrization. This is our "prior"; it's what our network expresses on initialization.
  • Vβ(f)V_\beta(f) is a volume with Gaussian measure that equals Pβ(f)P_\beta(f) under Gaussian sampling of network parameters.
    • This is a bit confusing. We're not talking about a continuous region of parameter space, but a bunch of variously distributed points and lower-dimensional manifolds. Mingard never explicitly points out why we expect a contiguous volume. That or maybe it's not necessary for it to be contiguous
  • β\beta denotes the "Bayesian prior"
  • Popt(fS)P_\text{opt}(f|S) is the probability of finding ff on EE under a stochastic optimizer like SGD trained to 100% accuracy on SS.
  • Pβ(fS)=P(Sf)Pβ(f)Pβ(S),P_\beta(f|S) = \frac{P(S|f) P_\beta(f)}{P_\beta(S)}, is the probability of finding ff on EE upon randomly sampling parameters from i.i.d. Gaussians to get 100% accuracy on SS.
    • This is what Mingard et al. call "Bayesian inference"
    • P(S|f)=1 if f is consistent with S and 0 otherwise.

Flat minima (i.e., minima where the loss Hessian has small eigenvalues) seem to be linked with generalization; a minimal flatness probe is sketched after the references below:

  • Hochreiter and Schmidhuber, 1997a; Keskar et al., 2016; Jastrzebski et al., 2018; Wu et al., 2017; Zhang et al., 2018; Wei and Schwab, 2019; Dinh et al., 2017.
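
As promised, here's a minimal sketch (mine, not from the cited papers) of one common flatness probe: train a tiny model to a minimum and inspect the eigenvalue spectrum of the loss Hessian there. The model, data, and sizes are arbitrary; at realistic scale you'd use Hessian-vector products rather than the full Hessian.

```python
import torch

torch.manual_seed(0)

# Tiny regression problem; the "network" is y_hat = w2 @ tanh(w1 @ x).
X = torch.randn(64, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(64)

def loss_fn(params):
    """Training loss as a function of a single flattened parameter vector."""
    w1 = params[:15].reshape(5, 3)
    w2 = params[15:].reshape(1, 5)
    preds = torch.tanh(X @ w1.T) @ w2.T
    return ((preds.squeeze() - y) ** 2).mean()

# Crude training loop to land near a minimum.
params = torch.randn(20, requires_grad=True)
opt = torch.optim.Adam([params], lr=1e-2)
for _ in range(2000):
    opt.zero_grad()
    loss_fn(params).backward()
    opt.step()

# "Flatness" probe: the eigenvalue spectrum of the Hessian at the found minimum.
H = torch.autograd.functional.hessian(loss_fn, params.detach())
eigs = torch.linalg.eigvalsh(H)
print("smallest eigenvalues:", eigs[:3])
print("largest eigenvalues:", eigs[-3:])
```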

SGD

Problems:

  • Suitable reparametrizations can change flatness without changing computation.

Initialization Favors "simple" functions

Mingard 2021

  • Levin et al. tell us that many real-world maps satisfy P(f) \lesssim 2^{-b\tilde K(f)+a}, where \tilde K(f) is a computable approximation of the true Kolmogorov complexity K(f).
  • Empirically, maps of the form f:{0,1}n{0,1}f:\{0, 1\}^n \to \{0, 1\} satisfy a similar upper bound using a computable complexity measure (CSR).
  • This is an upper bound, not more than that!
  • The initialization acts as our prior in Bayesian inference.

SGD Performs Hidden Regularization

Mingard 2020

  1. Bayesian inference preserves the "simplicity" of the prior.
  2. SGD performs a kind of "Bayesian inference"
    • You can approximate Pβ(f)P_\beta(f) with Gaussian Processes.
    • Mingard's main result is that Popt(fS)Pβ(fS)P_\text{opt}(f|S) \approx P_\beta(f|S) appears to hold for many datasets (MNIST, Fashion-MNIST, IMDb movie review, ionosphere), architectures (Fully connected, Convolutional, LSTM),  optimizers (SGD, Adam, etc.), training schemes (including overtraining) and optimizer hyperparameters (e.g. batch size, learning rate).
    • Optimizer hyperparameters matter much less than Pβ(fS)P_\beta(f|S)
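
A toy version of the underlying measurement might look like this sketch (loosely in the spirit of Mingard/Valle-Pérez-style experiments, but with all sizes and choices invented by me): repeatedly sample i.i.d. Gaussian parameters for a small network over Boolean inputs, record which function f it implements, and look at how concentrated the resulting prior P(f) is.

```python
import numpy as np
from collections import Counter
from itertools import product

rng = np.random.default_rng(0)

# All inputs of a small Boolean problem f: {0,1}^7 -> {0,1}.
X = np.array(list(product([0, 1], repeat=7)), dtype=float)

def random_function(width=40):
    """Sample i.i.d. Gaussian weights for a one-hidden-layer ReLU net (no biases)
    and read off the Boolean function it implements on all 2^7 inputs."""
    w1 = rng.normal(0, np.sqrt(2 / X.shape[1]), size=(X.shape[1], width))
    w2 = rng.normal(0, np.sqrt(2 / width), size=width)
    logits = np.maximum(X @ w1, 0) @ w2
    return tuple((logits > 0).astype(int))

counts = Counter(random_function() for _ in range(20000))
freqs = np.array(sorted(counts.values(), reverse=True))
print("distinct functions sampled:", len(counts))
print("probability mass of the 10 most frequent functions:", freqs[:10].sum() / freqs.sum())
```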


Singular Learning Theory

Glossary

  • ff^* is the true function
  • ff is some implemented function
  • MM is our neural network
  • DD is a dataset of pairs {(xi,yi)}i=1N\{(x_i, y_i)\}_{i=1}^N
  • S,EDS, E \subset D are the training & test sets, respectively.

Learnability

  • PAC in original formulation: simpler functions are easier to learn (polynomial time)

Usually, the true relation is probabilistic. In this case, we're not interested in a deterministic mapping f^* from elements of \mathcal X to elements of \mathcal Y, but in a probability distribution, q(x, y), which relates the random variables X and Y (where \mathcal X and \mathcal Y are the associated sample spaces).

Formally, we're trying to find the weights, \widehat w, that minimize the expected risk ("generalization error") for some choice of loss function, \ell: \mathcal Y \times \mathcal Y \to \mathbb R,

\widehat w := \underset{w \in \mathcal W}{\operatorname{argmin}}\ \mathbb E_{X,Y}[ \ell(f_w(X), Y)].

We don't have direct access to q(x, y), but we do have access to a dataset of samples, \mathbf D = \{(X_i, Y_i)\}_{i=1}^n, which is itself a random variable. We approximate the expectation with an empirical average over this dataset to get the empirical risk or "test loss".2 More accurately, we optimize \widehat w on a training set, \mathbf D_\text{train}, but report performance on a test set, \mathbf D_\text{test}, to avoid sampling bias and overfitting, where \mathbf D = \mathbf D_\mathrm{train} \cup \mathbf D_\mathrm{test} and \mathbf D_\mathrm{train} \cap \mathbf D_\mathrm{test} = \emptyset. 3

Footnotes

  1. Equivalently, we could separately model the likelihood p(xy)p(x|y) and prior p(y)p(y), then multiply to get the joint distribution. P.S. I prefer using qq and pp to mean the model and truth, respectively, but I'm keeping to the notation of Watanabe in Algebraic Geometry and Singular Learning Theory.

  2. Traditionally, statistical learning theory drops the explicit dependence on w, and instead looks at model selection at the level of f \in \mathcal F, q \in \mathcal Q. As we're interested in neural networks, we'll find it useful to fix a particular functional form of our model and look at selection of suitable parameters w. Later, we'll see that understanding the mapping w \mapsto f_w is at the heart of understanding why deep neural networks work as well as they do.

  3. From Guedj [1]: "The past few decades have thus seen an increasing gap between the Bayesian statistical literature, and the machine learning community embracing the Bayesian paradigm – for which the Bayesian probabilistic model was too much of a constraint and had to be toned down in its influence over the learning mechanism. This movement gave rise to a series of works which laid down the extensions of Bayesian learning[.]"

  4. TODO: Something something Safe Bayes

Path dependence

NOTE: Work in Progress

How much do the models we train depend on the path they follow through weight space?

Should we expect to always get the same models for a given choice of hyperparameters and dataset? Or do the outcomes depend highly on quirks of the training process, such as the weight initialization and batch schedule?

If models are highly path-dependent, it could make alignment harder: we'd have to keep a closer eye on our models during training for chance forays into deception. Or it could make alignment easier: if alignment is unlikely by default, then increasing the variance in outcomes increases our odds of success.

Vivek Hebbar and Evan Hubinger have already explored the implications of path-dependence for alignment here. In this post, I want to take a step back and survey the literature on path-dependence in current models. I'll also share the results of a few experiments on simple datasets like MNIST.

Formalizing Path-dependence

Hebbar and Hubinger define path-dependence "as the sensitivity of a model's behavior to the details of the training process and training dynamics." Let's make that definition more precise. To start, let us formalize the learning process as a dynamical system over weight space.

In the context of machine learning, our goal is (usually) to model some "true" function f:XYf^* : \mathcal X \to \mathcal Y1. To do this, we construct a neural network, Ff:X×WY\mathcal F \ni f: \mathcal X \times \mathcal W \to \mathcal Y, which induces a model of ff^* upon fixing a choice of weights, wWRDw \in \mathcal W \subset \mathbb R^D. For convenience, we'll denote the resulting model with a subscript, i.e., fw(x):=f(x,w)f_{w}(x):= f(x, w).2

To find the weights, woptw_\mathrm{opt}, such that fopt:=fwoptf_\text{opt} := f_{w_\text{opt}} is as "close" as possible to ff^*, we specify a loss function LnL_n, which maps a choice of parameters and a set of training examples, dtrain={(xi,yi)}i=1nZn\mathbf d_\text{train} = \{(x_{i},y_i)\}_{i=1}^n \in \mathcal Z^n, to a scalar "loss"3:

Ln:W×ZnR.L_n: \mathcal W \times \mathcal Z^n \to \mathbb R.

For a fixed choice of dataset, we denote the empirical loss, Ltrain(w)=L(w,dtrain)L_\text{train}(w) =L(w, \mathbf d_\text{train}). Analogously, we can define a test loss, LtestL_\text{test} over a corresponding test set, and batch loss, L(t)(w)L^{(t)}(w), for a batch b(t)dtrain\mathbf b^{(t)} \subset \mathbf d_\text{train}. Next, we choose an optimizer, Φ\Phi, which iteratively updates the weights (according to some variant of SGD),

Φ:W×B×HW(w(t), b(t), h)w(t+1). \begin{align} \Phi:\mathcal W \times \mathcal B\times \mathcal H &\to \mathcal W\\ (w^{(t)},\ b^{(t)},\ h) &\mapsto w^{(t+1)}. \end{align}

The optimizer depends on hyperparameters, h \in \mathcal H, as well as some learning schedule \mathbf b = \{\mathbf b^{(t)}\}_{t=1}^T of T batches, \mathbf b^{(t)} \in \boldsymbol{\mathcal B}. There are N_\text{epochs} epochs consisting of T_\text{epochs} batches each (i.e., T/ T_\text{epochs} = N_\text{epochs}). The learning schedule for epoch \tau is the subset \{\mathbf b^{(t)}\}_{t=\tau T_\text{epochs}}^{(\tau+1) T_\text{epochs} - 1}. Within any given epoch, the batches are disjoint, but across epochs, repetition is allowed (though typically no batch will be repeated sample for sample).

Example: for Adam, h=(β1,β2,ϵ,η)h=(\beta_1, \beta_2, \epsilon, \eta). These are, respectively, the momenta for the first moment and second moment of the gradients, a small term to ensure numerical stability, and the learning rate.

If we take the learning schedule and hyperparameters as constant, we obtain a discrete-time dynamical system over parameter space, which we denote Φb,h:WW\Phi_{\mathbf b, h}: \mathcal W \to \mathcal W.
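
As a minimal illustration (my own sketch, with a toy linear-regression loss standing in for the network), one step of \Phi and its iteration over a fixed batch schedule might look like:

```python
import torch

def sgd_step(w, batch, h):
    """One application of the map Phi: (w^(t), b^(t), h) -> w^(t+1) for plain SGD
    on a mean-squared-error loss. `h` holds the hyperparameters (here just the lr)."""
    x, y = batch
    w = w.detach().requires_grad_(True)
    loss = ((x @ w - y) ** 2).mean()
    grad = torch.autograd.grad(loss, w)[0]
    return (w - h["lr"] * grad).detach()

def run(w0, batches, h):
    """Iterate Phi over a fixed batch schedule: the discrete-time dynamical system Phi_{b,h}."""
    w = w0
    for batch in batches:
        w = sgd_step(w, batch, h)
    return w

# Toy usage with an arbitrary linear-regression problem.
torch.manual_seed(0)
X, y = torch.randn(256, 5), torch.randn(256)
batches = [(X[i:i + 32], y[i:i + 32]) for i in range(0, 256, 32)]
w_final = run(torch.zeros(5), batches * 10, {"lr": 0.1})  # 10 epochs, fixed schedule
print(w_final)
```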

Two Kinds of Path Dependence

Given this dynamical system over weight space, let us contrast two kinds of path dependence:

  1. Global path dependence. Given some distribution over starting weights, p(w(0))p(w^{(0)}), hyperparameters, p(h)p(h), or batches, p(b)p(\mathbf b), what is the resulting distribution over final weights, p(w(T))p(w^{(T)})?
  2. Local path dependence. For some choice of initial weights, w0w^0, hyperparameters, hh, or batches, b\mathbf b, and a small perturbation to one of these values, how different (e.g., in terms of the l2 norm) are the final weights of the perturbed model from the baseline unperturbed model?

The former enquiry is concerned with finding (semi-)stable distributions, while the latter is concerned with chaos (macroscopic sensitivity to microscopic perturbations in deterministic systems is the definition of chaos). Though we ultimately want to form global statements of path dependence, the distributions involved are generally intractable (and thus require us to resort to empirical approximations).

In the local view, we rarely care about specific perturbations, so we end up studying the relation between a distribution over perturbations p(δ)p(\delta) and p(w(T))p(w^{(T)}), which is similarly intractable. Still, this is a much smaller space to explore, which makes it friendlier to investigation.

From Weight Space to Function Space

The problem with studying dynamics over W\mathcal W is that we don't actually care about wWw \in \mathcal W. Instead, we care about the evolution in function space, F\mathcal F, resulting from the mapping m:WwfwFm: \mathcal W \ni w \mapsto f_{w}\in \mathcal F.

This is because the mapping to function space is non-injective; the internal symmetries of m ensure that different choices of w\in \mathcal W can map to the same function f \in \mathcal F.[1] As a result, notions of similarity and distance in \mathcal W can be misleading.

The difficulty with function space is that it is infinite-dimensional, which makes it that much more intractable than weight space. Though we have full knowledge of W\mathcal W, we can't resolve exactly where we are in F\mathcal F with a finite set of samples. At best, we can resolve the level set in F\mathcal F with a certain empirical performance (like training/test loss). Without a better understanding of the mapping mm, this leaves room for uncertainty between models that can perform wildly differently out-of-training or off-distribution.

On optimizer state: If we want to be exhaustive, let us recall that optimizers might have some internal state, sSs \in \mathcal S. E.g., for Adam, st=(mt,vt)s_{t}= (m_t, v_t), a running average of the first moment and second moment of the gradients, respectively. Then, the dynamical system over W\mathcal W is more appropriately a dynamical system over the joint space W×S\mathcal W \times \mathcal S. Completing this extension will have to wait for another day; we'll ignore it for now for the sake of simplicity and because, in any case, sts_t, tends to be a deterministic function of b\mathbf b and hh.

Before we run a few experiments, let's recap what the literature has to say on path-dependence so far.

Review

Evidence of Low Path Dependence

Internal Symmetries

Ainsworth, Hayase, and Srinivasa [2] find that, after correcting for permutation symmetries, differently trained weights are connected by a linear path with a close-to-zero loss barrier (linear mode connectivity). In other words, you can linearly interpolate between the permutation-corrected weights, and every point along the interpolation has essentially the same loss. They conjecture that there is only one global basin after correcting for these symmetries.

In general, correcting for these symmetries is NP-hard, so the argument of these authors depends on several approximate schemes to correct for the permutations [2].

![400](/media/pasted-image-20221220152152.png)

Universal Circuits

Some of the most compelling evidence for the low path-dependence world comes from the circuits-style research of Olah and collaborators. Across a range of computer vision models (AlexNet, InceptionV1, VGG19, ResnetV2-50), the circuits thread [3] finds features common to all of them such as curve and high-low frequency detectors [4], branch specialization [5], and weight banding [6]. More recently, the transformer circuits thread [7] has found universal features in language models, such as induction heads and bumps [8]. This is path independence at the highest level: regardless of architecture, hyperparameters, and initial weights different models learn the same things. In fact, low path-dependence ("universality") is often taken as a starting point for research on transparency and interpretability [4].

Universal circuits of computer vision models [4].

ML as Bayesian Inference

  • Pβ(f)P_\beta(f) is the probability that MM expresses ff on DD upon a randomly sampled parametrization. This is our "prior"; it's what our network expresses on initialization.

  • Vβ(f)V_\beta(f) is a volume with Gaussian measure that equals Pβ(f)P_\beta(f) under Gaussian sampling of network parameters.

    • This is a bit confusing. We're not talking about a continuous region of parameter space, but a bunch of variously distributed points and lower-dimensional manifolds. Mingard never explicitly points out why we expect a contiguous volume. That or maybe it's not necessary for it to be contiguous
  • β\beta denotes the "Bayesian prior"

  • Popt(fS)P_\text{opt}(f|S) is the probability of finding ff on EE under a stochastic optimizer like SGD trained to 100% accuracy on SS.

  • Pβ(fS)=P(Sf)Pβ(f)Pβ(S),P_\beta(f|S) = \frac{P(S|f) P_\beta(f)}{P_\beta(S)}, is the probability of finding ff on EE upon randomly sampling parameters from i.i.d. Gaussians to get 100% accuracy on SS.

    • This is what Mingard et al. call "Bayesian inference"
    • P(Sf)=1P(S|f)=1 if ff is consistent with SS and 00 otherwise
  • Double descent & Grokking

  • Mingard et al.'s work on NNs as Bayesian.

Evidence of High Path Dependence

  • Why Comparing Single Performance Scores Does Not Allow to Draw Conclusions About Machine Learning Approaches (Reimers et al., 2018)
  • Deep Reinforcement Learning Doesn’t Work Yet (Irpan, 2018)
  • BERTS of a feather do not flock together

Experiments

In section XX, we defined path dependence as understanding the relation between (P(B_t), P(W_0), P(H)) and (P(W_{T}), P(F_{T})). As mentioned, there are two major challenges:

  1. In general, we have to estimate these distributions empirically, and training neural networks is already computationally expensive, so we're restricted to studying smaller networks and datasets.
  2. The mapping m:WFm:\mathcal W\to \mathcal F is non-straightforward because of the symmetries of mm.

To make it easier for ourselves, let us restrict our attention to the narrower case of studying local path-dependence: a (small) perturbation, \epsilon, to one of (b_{t}, w_{0}, h). E.g., we'll study probability densities of the kind p(\mathbf w_0)=\mathcal N(\mathbf w_{0}|\mathbf w_\text{baseline}, \epsilon \mathbf 1). For discrete variables (like the number of layers, network width, and batch schedule), there's a minimum size we can make \epsilon.

Within dynamical systems theory, there are two main ways to view dynamical systems:

  1. The trajectory view studies the evolution of points like (w0,w1,,wT)(w_0, w_1, \dots, w_T) and (f1,f2,,fT)(f_1, f_2, \dots, f_T).
  2. The probability view studies the evolution of densities like (p0(w),p1(w),,pT(w))(p_0(w), p_1(w), \dots, p_T(w)) and (p1(f),p2(f),,pT(f))(p_1(f), p_2(f), \dots, p_T(f)).

Both have something to tell us about path dependence:

  1. In the trajectory view, sensitivity of outcomes becomes a question of calculating Lyapunov exponents, the rate at which nearby trajectories diverge/converge.
  2. In the probability view, sensitivity of outcomes becomes a question of measuring autocorrelations, wiwi+τ\langle w_{i} w_{i+\tau}\rangle and fifi+τ\langle f_{i} f_{i+\tau} \rangle.

Tracking Trajectories

The experimental set-up involves taking some baseline model, m0m_0, then applying a Gaussian perturbation to the initial weights with norm ϵ\epsilon, to obtain a set of perturbed models {mi}i=1nmodels\{m_i\} _{i=1}^{n _\text{models}}. We train these models on MNIST (using identical batch schedules for each model).

Over the course of training, we track how these models diverge via several metrics (see next subsection). For each of these, we study how the rate of divergence varies for different choices of hyperparameters: momentum β\beta, weight decay λ\lambda, learning rate η\eta, hidden layer width (for one-hidden-layer models), and number of hidden layers. Moving on from vanilla SGD, we compare adaptive variants like Adam and RMSProp.

TODO: Non-FC models. Actually calculate the rate of divergence (for the short exponential period at the start). Perform these experiments for several different batch schedules. Perturbations in batch schedule. Other optimizers besides vanilla SGD.

Measuring distance

  • dw(p)(mi,m0)d_\mathbf{w}^{(p)}(m_i, m_0): the pp-norm between the weight vectors of each perturbed model and the baseline model. We'll ignore the pp throughout and restrict to the case p=2p=2. This is the most straightforward notion of distance, but it's flawed in that it can be a weak proxy for distance in F\mathcal F because of the internal symmetries of the model.
  • dw(mi,m0)w0\frac{d_{\mathbf w}(m_{i}, m_0)}{|\mathbf w_{0}|}: the relative pp-norm between weight vectors of each perturbed model.
  • Ltrain,test(mi)L_\text{train,test}(m_i): the training/test losses
  • δLtrain, test(mi,m0)\delta L_\text{train, test}(m_{i}, m_0): the training/test loss relative (difference) to the baseline model m0m_0
  • ftrain, test(mi)f_\text{train, test}(m_i), δftrain, test(mi,m0)\delta f_\text{train, test}(m_i, m_0): the training/test set classification accuracy and relative classification accuracy.
  • Ltrain cf., test cf.L_\text{train cf., test cf.} and ftrain cf., test cf.f_\text{train cf., test cf.}: same as above, but we take the predictions of the baseline model as the ground truth (rather than the actual labels in the test set).

A few more metrics to include (a sketch of the basic metrics appears after this list):

  • dwperm.d^\text{perm.}_{\mathbf w}: the l2 norm after adjusting for permutation differences (as described in Ainsworth et al.)
  • w0wiperm.Ltrain (cf.), test (cf.)(w)dw\int_{\mathbf w_0}^{\mathbf w_{i}^\text{perm.}} L_\text{train (cf.), test (cf.)}(\mathbf w) d\mathbf w. The loss integrated over the linear interpolation between two models' weights after correcting for permutations (as described in Ainsworth et al.)
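
As referenced above, here's a rough sketch of a few of the basic metrics (the l2 distance d_w, its normalized version, and the "cf." accuracy that treats the baseline model's predictions as ground truth). The `model_i`, `model_0`, and `loader` names are placeholders for whatever models and data loader the experiment uses.

```python
import torch

def weight_distance(model_i, model_0):
    """d_w: l2 norm between the flattened weight vectors of two models,
    plus its version normalized by the baseline norm."""
    w_i = torch.cat([p.flatten() for p in model_i.parameters()])
    w_0 = torch.cat([p.flatten() for p in model_0.parameters()])
    d = torch.norm(w_i - w_0)
    return d, d / torch.norm(w_0)

@torch.no_grad()
def accuracy_cf(model_i, model_0, loader):
    """f cf.: classification accuracy of model_i when the baseline model's
    predictions are treated as the ground-truth labels."""
    agree, total = 0, 0
    for x, _ in loader:
        preds_i = model_i(x).argmax(dim=1)
        preds_0 = model_0(x).argmax(dim=1)
        agree += (preds_i == preds_0).sum().item()
        total += len(x)
    return agree / total
```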

Tracking Densities

TODO

Results

One-hidden-layer Models

What's remarkable about one-hidden-layer models is how little the model depends on weight initialization: almost all of the variance seems to be explained by the batch schedule. Even for initial perturbations of size ϵ=10\epsilon=101, the models appear to become almost entirely equivalent in function space across the entire training process. In the figure below, you can see that the differences in LtrainL_\text{train} are imperceptible (both for a fixed ϵ\epsilon and across averages over different ϵ\epsilon).

TODO: I haven't checked ϵ/w\epsilon/|\mathbf w|. I'm assuming this is >1 for ϵ=10\epsilon=10 (which is why I find this surprising), but I might be confused about PyTorch's default weight initialization scheme. So take this all with some caution.



The rate of growth for dwd_\mathbf{w} has a very short exponential period (shorter than a single epoch), followed by a long linear period (up to 25 epochs, long past when the error has stabilized). In some cases, you'll see a slight bend upwards or downwards. I need more time to test over more perturbations (right now it's 10 perturbed models) to clean this up. Maybe these trends change over longer periods, and with a bit more time I'll test longer training runs and over more perturbations.


Hyperparameters

  • Momentum: (\beta=0.1, 0.5, 0.9) The d_\mathbf w curves look slightly more curved upwards/exponential, but maybe that's confirmation bias (I expected momentum to make divergence more exponential back when I expected exponential divergence). Small amounts of momentum appear to increase convergence (relative to other models equivalent up to momentum), while large amounts increase divergence. For very small amounts, the effect appears to be negligible. Prediction: the d_w curves curve down because in flat basins, momentum slows you down.
  • Learning rate: Same for η=103,102,101\eta=10^{-3}, 10^{-2}, 10^{-1}. (Not too surprising) TODO: Compare directly between different learning rates. Can we see η\eta back in the slope of divergence? Prediction: This should not make much of a difference. I expect that correcting for the learning rate should give nearly identical separation slopes.
  • Weight Decay: TODO for \lambda=10^{-3}, 10^{-2}, 10^{-1}. Prediction: d_w will shrink (because |w| will shrink). The normalized d_w/|w| will curve downwards because we break the ReLU scaling symmetry, which means the volume of our basins will shrink.
  • Width: TODO. Both this and depth require care. With different sizes of w\mathbf w, we can't naively compare norms. Is it enough to divide by dimw\dim \mathbf w? Prediction: I expect the slope of divergence to increase linearly with the width of a one-layer network (from modeling each connection as separating of its own linear accord).
  • Depth: TODO. Prediction: I expect the shape of dwd_w for each individual layer to become more and more exponential with depth (because the change will be the product of layers that diverge linearly independently). I expect this to dominate the overall divergence of the networks.
  • Architecture: Prediction: I expect convolutional architectures to decrease the rate of separation (after correcting for the number of parameters).
  • Optimizer:
    • Prediction: I expect adaptive techniques to make the rate of divergence exponential.

Datasets

  • Computer Vision
    • MNIST
    • Fashion-MNIST
    • Imagenet: Prediction: I expect Imagenet to lead to the same observations as MNIST (after correcting for model parameter count).
  • Natural Language
    • IMDb movie review database

Why Linear?

I'm not sure. My hunch going in was that I'd see either exponential divergence (indicating chaos) or square root divergence (i.e., Brownian noise). Linear surprises me, and I don't yet know what to make of it.

Maybe all of this is just a fact about one-hidden-layer networks. If each hidden layer evolves independently linearly, maybe this combines additively into something Brownian. Or maybe it's multiplicative (so exponential).

Deep Models

TODO

Appendix

Weight Initialization

For ReLU-based networks, we use a modified version of Kaiming (normal) initialization, which is based on the intuition of Kaiming initialization as sampling weights from a hyperspherical shell of vanishing thickness.

Kaiming initialization is sampling from a hyperspherical shell

Consider a matrix, w(l)\mathbf w^{(l)}, representing the weights of a particular layer ll with shape (Din(l),Dout(l+1))(D_\mathrm{in}^{(l)}, D_\mathrm{out}^{(l+1)}). Din(l)D_\mathrm{in}^{(l)} is also called the fan-in of the layer, and Din(l+1)D_\mathrm{in}^{(l+1)} the fan-out. For ease of presentation, we'll ignore the bias, though the following reasoning applies equally well to the bias.

We're interested in the vectorized form of this matrix, w(l)RD(l)\vec w^{(l)} \in \mathbb R^{D^{(l)}}, where D(l)=Din(l)×Dout(l+1)D^{(l)} =D_\mathrm{in}^{(l)} \times D_\mathrm{out}^{(l+1)}.

In Kaiming initialization, we sample the components, wi(l)w_i^{(l)}, of this vector, i.i.d. from a normal distribution with mean 0 and variance σ2\sigma^2 (where σ2=2Din(l)\sigma^2 = \frac{2}{D_\mathrm{in}^{(l)}}).

Geometrically, this is equivalent to sampling from a hyperspherical shell, SD1S^{D-1} with radius Dσ\sqrt{D}\sigma and (fuzzy) thickness, δ\delta.

This follows from some straightforward algebra (dropping the superscript ll for simplicity):

E[w2]=E[i=1Dwi2]=i=1DE[wi2]=i=1Dσ2=Dσ2,\mathbb E[|\mathbf w|^2] = \mathbb E\left[\sum_{i=1}^D w_i^2\right] = \sum_{i=1}^D \mathbb E[w_i^2] = \sum_{i=1}^D \sigma^2 = D\sigma^2,

and

δ2var[w2]=E[(i=1Dwi2)2]E[i=1Dwi2]2=i,j=1DE[wi2wj2](Dσ2)2=ijDE[wi2]E[wj2]+i=1DE[wi4](Dσ2)2=D(D1)σ4+D(3σ4)(Dσ2)2=2Dσ4.\begin{align} \delta^2 \propto \mathrm{var} [|\mathbf w|^2] &= \mathbb E\left[\left(\sum_{i=1}^D w_i^2\right)^2\right] - \mathbb E\left[\sum_{i=1}^D w_i^2\right]^2 \\ &= \sum_{i, j=1}^D \mathbb E[w_i^2 w_j^2] - (D\sigma^2)^2 \\ &= \sum_{i \neq j}^D \mathbb E[w_i^2] \mathbb E[w_j^2] + \sum_{i=1}^D \mathbb E[w_i^4]- (D\sigma^2)^2 \\ &= D(D-1) \sigma^4 + D(3\sigma^4) - (D\sigma^2)^2 \\ &= 2D\sigma^4. \end{align}

So the thickness as a fraction of the radius is

\frac{\delta}{\sqrt{D}\sigma} = \frac{\sqrt{2D}\,\sigma^{2}}{\sqrt{D}\,\sigma} = \sqrt{2}\sigma = \frac{2}{\sqrt{D_\mathrm{in}^{(l)}}},

where the last equality follows from the choice of σ\sigma for Kaiming initialization.

This means that for suitably wide networks (Din(l)D_\mathrm{in}^{(l)} \to \infty), the thickness of this shell goes to 00.

Taking the thickness to 0

This suggests an alternative initialization strategy: sample directly from the boundary of a hypersphere with radius Dσ\sqrt{D}\sigma, i.e., modify the shell thickness to be 00.

This can easily be done by sampling each component from a normal distribution with mean 0 and variance 11 and then normalizing the resulting vector to have length Dσ\sqrt{D}\sigma (this is known as the Muller method).
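
A sketch of that initializer (my own, assuming the Kaiming scaling \sigma^2 = 2/\text{fan\_in} described above):

```python
import numpy as np

def kaiming_hypersphere_init(fan_in, fan_out, rng):
    """Sample a (fan_in x fan_out) weight matrix uniformly from the hypersphere of
    radius sqrt(D) * sigma, with sigma^2 = 2 / fan_in (Kaiming scaling) and
    D = fan_in * fan_out. This is the Muller method: normalize a Gaussian vector."""
    sigma = np.sqrt(2.0 / fan_in)
    D = fan_in * fan_out
    w = rng.normal(0.0, 1.0, size=D)
    w *= np.sqrt(D) * sigma / np.linalg.norm(w)
    return w.reshape(fan_in, fan_out)

rng = np.random.default_rng(0)
w = kaiming_hypersphere_init(784, 128, rng)
print(np.linalg.norm(w))                      # exactly sqrt(D) * sigma
print(np.sqrt(784 * 128) * np.sqrt(2 / 784))  # the target radius, for comparison
```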

Perturbing Weight initialization

Naïvely, if we're interested in a perturbation analysis of the choice of weight initialization, we prepare some baseline initialization, w0\mathbf w_0, and then apply i.i.d. Gaussian noise, δ\boldsymbol \delta, to each of its elements, δiN(0,ϵ2)\delta_i \sim \mathcal N(0, \epsilon^2).

The problem with this is that the perturbed weights w=w0+δ\mathbf w = \mathbf w_0 + \boldsymbol\delta are no longer sampled from the same distribution as the baseline weights. In terms of the geometric picture from the previous section, we're increasing the thickness of the hyperspherical shell in the vicinity of the baseline weights.

There is nothing wrong with this per se, but it introduces a possible confounder (the thickness).

The modification we made to Kaiming initialization was to sample directly from the boundary of a hypersphere, rather than from a hyperspherical shell. This is a more natural choice when conducting a perturbation analysis, because it makes it easier to ensure that the perturbed weights are sampled from the same distribution as the baseline weights.

Geometrically, the intersection of a hypersphere SDS^D of radius w0=w0w_0=|\mathbf w_0| with a hypersphere SDS^D of radius ϵ\epsilon that is centered at some point on the boundary of the first hypersphere, is a lower-dimensional hypersphere SD1S^{D-1} of a modified radius ϵ\epsilon'. If we sample uniformly from this lower-dimensional hypersphere, then the resulting points will follow the same distribution over the original hypersphere.

This suggests a procedure to sample from the intersection of the weight initialization hypersphere and the perturbation hypersphere.

First, we sample from a hypersphere of dimension D-1 and radius \epsilon' (using the same technique we used to sample the baseline weights). From a bit of trigonometry (see figure below), we know that the radius of this intersection sphere will be \epsilon' = w_0\sin \theta, where \theta = \cos^{-1}\left(1-\frac{\epsilon^2}{2w_0^2}\right).


Next, we rotate the vector so it is orthogonal to the baseline vector w0\mathbf w_0. This is done with a Householder reflection, HH, that maps the current normal vector n^=(0,,0,1)\hat{\mathbf n} = (0, \dots, 0, 1) onto w0\mathbf w_0:

H=I2ccTcTc,H = \mathbf I - 2\frac{\mathbf c \mathbf c^T}{\mathbf c^T \mathbf c},

where

c=n^+w^0,\mathbf c = \hat{\mathbf n} + \hat {\mathbf w}_0,

and w^0=w0w0\hat{\mathbf w}_0 = \frac{\mathbf w_0}{|w_0|} is the unit vector in the direction of the baseline weights.

Implementation note: For the sake of tractability, we directly apply the reflection via:

Hy=y2cTycTcc.H\mathbf y = \mathbf y - 2 \frac{\mathbf c^T \mathbf y}{\mathbf c^T\mathbf c} \mathbf c.

Finally, we translate the rotated intersection sphere along the baseline vector, so that its boundary goes through the intersection of the two hyperspheres. From the figure above, we find that the translation has the magnitude w0=w0cosθw_0' = w_0 \cos \theta.

By the uniform sampling of the intersection sphere and the uniform sampling of the baseline vector, we know that the resulting perturbed vector will have the same distribution as the baseline vector, when restricted to the intersection sphere.
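
Putting the pieces together, a sketch of the full perturbation sampler might look like the following (using the radius w_0 \sin\theta for the intersection sphere and w_0 \cos\theta for the translation; the variable names are mine):

```python
import numpy as np

def perturb_on_sphere(w0, eps, rng):
    """Sample weights at distance eps from w0, uniformly over the intersection of
    the initialization hypersphere |w| = |w0| with the perturbation hypersphere
    |w - w0| = eps (the procedure described above)."""
    D = w0.size
    r = np.linalg.norm(w0)
    w0_hat = w0 / r
    theta = np.arccos(1 - eps**2 / (2 * r**2))

    # Uniform sample from the intersection sphere (radius r*sin(theta)),
    # living in the subspace orthogonal to the last coordinate axis.
    v = rng.normal(size=D)
    v[-1] = 0.0
    v *= r * np.sin(theta) / np.linalg.norm(v)

    # Householder reflection with c = n_hat + w0_hat maps the subspace
    # orthogonal to n_hat = (0, ..., 0, 1) onto the subspace orthogonal to w0_hat.
    c = np.zeros(D)
    c[-1] = 1.0
    c = c + w0_hat
    v = v - 2 * (c @ v) / (c @ c) * c

    # Translate along the baseline direction so we land on the intersection.
    return r * np.cos(theta) * w0_hat + v

rng = np.random.default_rng(0)
w0 = rng.normal(size=1000)
w = perturb_on_sphere(w0, eps=0.5, rng=rng)
print(np.linalg.norm(w) - np.linalg.norm(w0))  # ~0: same radius as the baseline
print(np.linalg.norm(w - w0))                  # ~0.5: the requested perturbation size
```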

sampling-perturation 1.png

Dynamical Systems

Formally, a dynamical system is a tuple (T,M,Φ)(\mathcal T, \mathcal M, \Phi), where T\mathcal T is the "time domain" (some monoid), M\mathcal M is the phase space over which the evolution takes place (some manifold), and Φ\Phi is the evolution function, a map,

Φ:T×MM,\Phi: \mathcal T \times \mathcal M \to \mathcal M,

that satisfies, xM,t1,t2T\forall x \in \mathcal M, \forall t_{1}, t_{2}\in \mathcal T,

Φ(0,x)=x,Φ(t2,Φ(t1,x))=Φ(t2+t1,x).\begin{align} \Phi(0, x) &= x,\\ \Phi(t_{2}, \Phi(t_{1}, x)) &= \Phi(t_{2}+t_{1}, x). \end{align}

4

Informally, we're usually interested in one of two perspectives:

  1. Trajectories of individual points in M\mathcal M, or
  2. Evolution of probability densities over M\mathcal M.

The former is described in terms of differential equations (for continuous systems) or difference equations (for discrete systems), i.e.,

xt=ϕ(x,t)\frac{\partial x}{\partial t} = \phi(x, t)

or

xt+1xt=ψ(xt,t)x_{t+1}- x_{t} = \psi(x_{t}, t)

The latter is described in terms of a transfer operator:

tp(x)=Lp(x).\nabla_{t} p(x) = \mathcal L p(x).

Both treatments have their advantages: the particle view admits easier empirical analysis (we just simulate a trajectory), while the probability view lends itself to statements about ensembles of trajectories (and thus to global claims about path dependence).

In terms of the former, path-sensitivity becomes the question of chaos: do nearby trajectories converge or diverge?

Lyapunov spectrum

The Lyapunov exponent λ\lambda quantifies the rate of separation of infinitesimally close trajectories:

δZ(t)eλtδZ0. |\delta \mathbf{Z}(t)| \approx e^{\lambda t}\left|\delta \mathbf{Z}_0\right|.

If the exponent is positive, then nearby trajectories will diverge, and the dynamics are chaotic. If the exponent is negative, then nearby trajectories will converge.

For a DD-dimensional system, there are DD Lyapunov exponents. We often focus exclusively on the maximal Lyapunov exponent because it dominates the overall rate of separation between two neighboring trajectories. However, the full spectrum contains valuable additional information like the rate of entropy production, the fractal dimension, the Hausdorff dimension, and the Lyapunov dimension.

The first major limitation is that the full Lyapunov spectrum is intractable for large D. The other major limitation is that the Lyapunov spectrum requires a suitable norm. We can use the l2 norm in \mathcal W, but as we already saw, what we really care about is measuring distance in \mathcal F with some metric d : \mathcal F \times \mathcal F \to \mathbb R.

One option would be to repurpose the loss function \ell:

d(f1,f2)=1Ni=1N(f1(xi),f2(xi)), d(f_1, f_2) = \frac{1}{N} \sum_{i=1}^N \ell(f_1(x_i), f_2(x_i)),

Here, we treat f2(xi)f_2(x_i) as the truth, and use the empirical risk as our distance metric. We could define an analogous distance over either the train or test set. As long as \ell is a suitable metric (or easily converted into a metric through, e.g., symmetrization), dd will be too.

This suffers a major downside: two functions will be seen as identical as long as they have the same performance on the dataset in question, even though they can have totally different performance on unseen examples. Devising suitable distance metrics for neural networks requires more work.


Autocorrelations

The autocorrelation of a random process {Xt}\{X_t\} is the correlation between two values of that process at different times,

RXX(t1,t2)=Xt1Xt2XX. R_{XX}(t_1, t_2) = \langle X_{t_1} X_{t_2}\rangle_{XX}.

For stationary processes (where Xt\langle X_t\rangle is independent of tt), this becomes a function of one variable, τ=t2t1\tau=t_2-t_1,

RXX(τ)=XtXt+τXX. R_{XX}(\tau) = \langle X_tX_{t+\tau}\rangle_{XX}.

In the case of training, the dynamics aren't stationary: learning rate schedules and the "descent" of gradient descent ensure that these correlations will depend on our choice of starting time. However, we're not typically interested in correlations between later time steps. We care about autocorrelations relative to the starting point t_1=0.

In practice, the probabilistic and point-wise views are intimately connected: the relevant autocorrelation timescales can often be directly related to the maximal Lyapunov exponent. These two formulas are only a small sample of the available tooling.

F\mathcal F is the space of pp-integrable functions, fwFf_w \in \mathcal F iff

Xfw(x)pdx<,\int_{\mathcal X} |f_{w}(x)|^{p}\, \mathrm dx < \infty,

which is equipped with a metric,

d_{\mathcal F}(f, g) = \left(\int_{\mathcal X}|f(x)-g(x)|^{p}\, \mathrm d x\right)^{1/p}.

So we exchange evolution over a finite-dimensional \mathcal W for evolution over an infinite-dimensional \mathcal F.

Footnotes

  1. For regression, Y=RN\mathcal Y = \mathbb R^N, for classification, YN\mathcal Y \subset \mathbb N, for self-supervised tasks, we're often interested in Y=X\mathcal Y = \mathcal X, etc. 2

  2. For convenience, we'll start by ignoring model hyperparameters. You can view this as either subsuming the hyperparameters into the weights, into the optimizer, or into our choice of ff. Later, it'll be useful to separate out hyperparameters (i.e., f:X×Y×HYf: \mathcal X \times \mathcal Y \times \mathcal H \to \mathcal Y).

  3. A quick note on notation: Often, you'll see the dataset denoted with a capital DD. Here, we're using a lowercase because it'll be useful to treat d\mathbf d as an instance of a random variable D\mathbf D. Similarly for xx, yy, and ww (XX, YY, and WW). Even though xx, yy, and ww are (typically) all vectors, we'll reserve boldface for sets. TODO: Maybe just bite the bullet and bold all of them.

  4. It's possible to generalize this further so that Φ:U(T×M)M\Phi : U \subseteq (\mathcal T \times \mathcal M) \to \mathcal M, but this won't be necessary for us.

Review of Generalization

NOTE: Work in Progress

Outline

  • Warm-up: MSE & bias-variance trade-offs.
  • Generalization = Simplicity.
    • Solomonoff induction,
    • Complexity measures
      • BIC/MDL, AIC
      • VC dimension & structural risk minimization
      • Rademacher complexity
  • Early Stopping
    • NTK
  • Problems:
    • Vacuous bounds (Zhang et al. 2017). Always zero training error. Any optimizer (even stupid ones). Crazily convex loss spots. Even with regularization
  • From functions to probabilities
  • Mingard-style “Bayesian inference”

Misc:

  • SGD & Basin broadness
  • Heavy-tailed random matrix theory (implicit self-regularization)
  • Double descent & grokking
  • Recent Anthropic work It’s non-obvious that (l2) regularization chooses simpler functions in NNs. (Polynomial regression is one thing).

"Regular" learning theory (RLT) predicts that overparametrized deep neural networks (DNNs) should overfit. In practice, the opposite happens: deeper networks perform better and generalize further. What's going on?

The consensus is still out, but different strands of research appear to be converging on a common answer: DNNs don't overfit because they're not actually overparametrized; SGD and NN's symmetries favor "simpler" models whose effective parameter counts are much lower than the externally observed parameter count. Generality comes from a kind of internal model selection where neural networks throw out unnecessary expressivity.

In this post, I'll make that statement more precise. We'll start by reviewing what RLT has to say about generalization and why it fails to explain generalization in DNNs. We'll close with a survey of more recent attempts at the learning theory of DNNs with particular emphasis on "singular" learning theory (SLT).

Intro to Statistical Learning Theory

Before we proceed, we need to establish a few definitions. Statistical learning theory is often sloppy with its notation and assumptions, so we have to start at the basics.

In this post, we'll focus on a (self-)supervised learning task. Given an input xXx \in \mathcal X, we'd like to predict the corresponding output yYy \in \mathcal Y. This relationship is determined by an unknown "true" pdf q(z)=q(x,y)q(z) = q(x, y) over the joint sample space Z=X×Y\mathcal Z = \mathcal X \times \mathcal Y.

If we're feeling ambitious, we might directly approximate the true joint distribution with a generative model p(x,yw)p(x, y|w) that is parametrized by weights wWRDw \in \mathcal W \subseteq \mathbb R^D. If it's tractable, we can marginalize this joint distribution to obtain the marginal and conditional distributions. 1

If we're feeling less ambitious, we might go ahead with a discriminative model of the conditional distribution, p(yx,w)p(y| x, w). (In the case of regression and classification, we're interested in inferring target values and labels from future samples, so we may not care all that much about q(x)q(x) and q(xy)q(x|y).)

And if we're feeling unambitious (as most machine learning theorists are), it may be enough to do direct inference with a discriminant function — i.e., to find a deterministic function that maps inputs to outputs, fw:XYf_w: \mathcal X \to \mathcal Y.2

In the case of regression, this is justified on the (only occasionally reasonable) assumption that the error is i.i.d. sampled from a Gaussian distribution. In other words, the discriminative model is defined in terms of ff as:

p(yx,w)=N(yfw(x),ϵ).p(y|x, w) = \mathcal N(y | f_{w}(x), \epsilon).

Machine learning usually begins with discriminant functions. Not only is this often easier to implement, but it offers some additional flexibility in the kinds of functions we can model because we optimize arbitrary loss functions that are unmoored from Bayesian-grounded likelihoods or posteriors.3

This flexibility comes at a cost: without the theoretical backing, it's more difficult to reason systematically about what's actually going on. The loss functions and regularizers of machine learning often end up feeling ad hoc and unsupported because they are. Or at least they will be until section 3.

As a result, we'll find it most useful to work at the intermediate level of analysis, in terms of discriminative models.

What Is Learning, Anyway?

The strength of the Bayesian framing is that for some prior over weights, ϕ(w)\phi(w), and a dataset d(n)Zn\mathbf d^{(n)} \in \mathcal Z^n of samples {(xi,yi)}i=1n\{(x_i, y_i)\}_{i=1}^n, we can "reverse" the likelihood (assuming each sample is sampled i.i.d. from p(x,y)p(x, y)),

p(d(n)w)=i=1np(xi,yiw),p(\mathbf d^{(n)}|w) = \prod_{i=1}^{n}p(x_{i}, y_{i}|w),

to obtain the a posteriori distribution,

p(wd(n))=1Znp(d(n)w)ϕ(w),p(w|\mathbf d^{(n)}) = \frac{1}{Z_{n}}p(\mathbf d^{(n)}|w)\phi(w),

where ZnZ_n is the evidence (or partition function),

Zn=Wp(d(n)w)ϕ(w)dw.Z_{n}= \int_{\mathcal W} p(\mathbf d^{(n)}|w)\,\phi(w)\,\mathrm dw.

The aim of "learning" is to make our model p(yx,w)p(y|x, w) as "close" as possible to the truth q(yx)q(y|x).

In the probabilistic formulation, the natural choice of "distance" is the Kullback-Leibler divergence (which is not actually a distance):

K(w) := D_\text{KL}\left(q(x, y)\,||\,p(x, y| w)\right) = \left\langle \log \left(\frac{q(y|x)}{p(y|x, w)}\right) \right\rangle_{X, Y},

where the second equality follows from defining p(x, y|w) = p(y|x, w) \cdot q(x).

Formally, the aim of learning is to solve the following:

w=argminwK(w)w^* = \underset{w}{\operatorname{argmin}} K(w)

We say the true distribution q(y|x) is realizable iff K(w^*) = 0, in which case we call w^* the "true" parameters.

At this point, we run into a bit of a problem: we can't actually compute the expectations \langle \cdot\rangle_{X, Y} since we don't know q(x, y) (and even if we did, the integration would likely be intractable). Instead, we have to resort to empirical averages over a dataset.3

Maximum Likelihood Estimation is KL-Divergence Minimization

Example: Regression & Mean-Squared Error

Fortunately, there's a more general and principled way to recover discriminant models from discriminant functions than the assumption of isotropic noise, as long as we're willing to relax Bayes' rule.

First, we introduce a parameter \beta \geq 0 that controls the tradeoff between prior and likelihood to obtain the tempered Bayes update [1],4

pβ(wd)p(dw)βp(w),p_\beta(w|\mathbf d) \propto p(\mathbf d| w)^{\beta}p(w),

where d\mathbf d is a dataset of samples {(xi,yi)}i=1n\{(x_{i},y_{i})\}_{i=1}^n, and the likelihood is obtained by assuming each sample is distributed i.i.d. according to p(xi,yiw)p(x_{i}, y_{i}|w) s.t. p(dw)=i=1np(xi,yiw)p(\mathbf d|w) = \prod_{i=1}^{n}p(x_{i}, y_{i}|w).

Next, we replace the (tempered) negative log-likelihood with a general loss function, \ell, to obtain the generalized ("Gibbs") posterior, p_\beta(w|\mathbf d) \propto e^{-\beta \sum_{i=1}^{n} \ell(f_w(x_i), y_i)}\, p(w).

The real world is noisy, and pure function approximation can't account for this.

To remedy this, it's common to model any departure from our deterministic predictions, \hat y = f_w(x), as isotropic noise, e.g., p(y|x, w) = \mathcal N(y\,|\, f_w(x), \sigma^2).

Physically, we may expect the underlying process to be deterministic (given by a "true" function ff^*) with noise creeping in through some unbiased measurement error ϵN(0,σ2)\epsilon \sim \mathcal N(0, \sigma^2).

For Gaussian noise, the NLL works out to:

Ld(w)=12σ2Ni=1N(yiy^i)2+const.L_\mathbf d(w) = \frac{1}{2\sigma^2N}\sum_{i=1}^N (y_i-\hat y_i)^2 + \text{const}.

Assuming isotropic Gaussian noise in a regression setting, we see that MLE simplifies to minimizing the mean squared error (up to some overall constant and scaling).

To make MLE more Bayesian, you just multiply the likelihood by some prior over the weights to obtain maximum a posteriori estimation (MAP): p(w|\mathbf d) \propto p(\mathbf d|w)\,\phi(w). If we enforce a Gaussian prior with precision \alpha, the loss gains a weight-decay regularization term, \frac{\alpha}{2}w^T w. [2]
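To make this concrete, here's a minimal numpy sketch (the toy data, names, and constants are mine) checking that the Gaussian negative log-likelihood is just the rescaled mean squared error plus a w-independent constant, and showing where the weight-decay term enters:

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma, alpha = 100, 0.5, 1e-2

# Toy "network": f_w(x) = x @ w, with Gaussian observation noise.
X = rng.normal(size=(n, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + sigma * rng.normal(size=n)

def mse(w):
    return np.mean((y - X @ w) ** 2)

def nll(w):
    """Average negative log-likelihood under y | x ~ N(f_w(x), sigma^2)."""
    return mse(w) / (2 * sigma**2) + 0.5 * np.log(2 * np.pi * sigma**2)

def map_loss(w):
    """NLL plus the weight-decay term from a Gaussian prior (constants folded into alpha)."""
    return nll(w) + 0.5 * alpha * w @ w

w = rng.normal(size=3)
# The NLL and the rescaled MSE differ only by an additive constant, so they share minimizers:
print(np.isclose(nll(w) - mse(w) / (2 * sigma**2), 0.5 * np.log(2 * np.pi * sigma**2)))
print(map_loss(w))
```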

Gibbs Generalization Error

Todo

What Is Generalization, Anyway?

Consider the empirical probability distribution,

q(zd)=1Nzidδ(zzi), q(z|\mathbf d) = \frac{1}{N}\sum_{z_i \in \mathbf d} \delta(z-z_i),

where δ()\delta(\cdot) is a delta function appropriate to the sample space.

Although this will converge to the true distribution as N\to\infty, for any finite N, we encounter sampling error. This problem is particularly pronounced for unseen samples: q(z|\mathbf d) may (and almost always does) assign zero probability to samples that have non-zero probability under the true distribution.

The fundamental challenge of generalization is to make predictions about these unseen samples. Let's make this more precise.

Given some loss function \ell: \mathcal Y \times \mathcal Y \to \mathbb R, we're interested in how a candidate function f performs on future samples, as measured by the expected risk or generalization error,

I : FRf(f(x),y)X,Y.\begin{align} I\ :\ \mathcal F &\to \mathbb R \\ f &\mapsto \left\langle \ell(f(x), y)\right\rangle_{X,Y}.\\ \end{align}

However, we can only estimate this performance on the available dataset via the empirical risk,

Id : FRf(f(x),y),\begin{align} I_\mathbf d\ :\ \mathcal F &\to \mathbb R \\ f &\mapsto \overline{\ell(f(x), y)},\\ \end{align}

where \overline \cdot denotes an empirical average over the dataset, d\mathbf d.

Our problem is that the true generalization error is unknowable. So, in practice, we split our dataset into a training set, \mathbf s, and a test set, \mathbf e. We learn our parameters w via the training set, and then estimate the generalization error on the test set. That's where we encounter our first major confusion.3

You see, besides the generalization error, there's the distinct notion of the generalization gap,

\epsilon_g = I_{\mathbf e}[f_w] - I_{\mathbf s}[f_w].

Whereas generalization error is about absolute performance on the true distribution, the generalization gap is about relative performance between the training and test sets. Generalization error combines both "performance" and "transferability" into one number, while the generalization gap is more independent of "performance" — past a certain threshold, any amount of test-set performance is compatible with both a low and high generalization gap.

For the remainder of this post, I'll focus on the generalization gap. It's not absolute performance we're interested in: universal approximation theorems tell us that we should expect excellent training performance for neural networks. The less obvious claim is why this performance should transfer so well to novel samples.


Yes, more data means better generalization, but that's not what we're talking about.

In the limit N \to \infty, the empirical risk converges to the expected risk, so the generalization gap (and its square) vanishes: \epsilon_g \to 0,\quad \epsilon_g^2 \to 0 \quad \text{as} \quad N \to \infty. So obviously the larger datasets that have accompanied the deep learning boom explain some of the improvement in generalization.

PAC Learning

Introduced by Valiant (1984) [3], Probably Approximately Correct (PAC) learning establishes upper bounds for the risk I[f]. In its simplest form, it states that, for any choice of \epsilon ("probably"), a predictor, f, will have its risk bounded by some \delta ("approximately correct") with high probability:

P[I[f]δ]1ϵ.\mathbb{P}[I[f] \leq \delta] \geq 1-\epsilon.

Originally, PAC learning included the additional requirement that the predictor be found in time polynomial in N and 1/\epsilon, but the term has since been relaxed to refer to any bound of this kind that holds with high probability.

Still, it's not a full explanation as typical networks can easily memorize the entire training set (even under random labelings [4]).


MSE & Bias-variance tradeoff

In the case of isotropic noise (where our loss function is the MSE), the generalization error of a model is:

Generalization error=(yy^)2X,Y=Bias2+Variance+Irreducible error.\text{Generalization error} = \left\langle(y - \hat y)^2 \right\rangle_{X,Y} = \text{Bias}^2 + \text{Variance} + \text{Irreducible error}.

In other words, it is the (squared l2) distance between our predictions and the truth, averaged over the true distribution. To decompose it, we also have to remember that the learned function f_w is itself random: it depends on the particular training set we sampled. Averaging over datasets \mathbf D as well, the bias-variance decomposition splits the generalization error into a bias term,

\text{Bias} = \left\langle \langle f_w(x)\rangle_{\mathbf D}-y\right\rangle_{X,Y},

which measures the average difference between predictions and true values, a variance term,

\text{Variance} = \left\langle \left(f_w(x) - \langle f_w(x)\rangle_{\mathbf D} \right)^2\right\rangle_{X, \mathbf D},

which measures the spread of a model's predictions for a given point in the input space, and an irreducible error due to the inherent noise in the data. (For Gaussian noise, the irreducible error is σ2\sigma^2.)

In order to achieve good generalization performance, a model must have low bias and low variance. This means that the model must be complex enough to capture the underlying patterns in the data, but not so complex that it overfits the data and becomes sensitive to noise. In practice, these two forces are at odds, hence "tradeoff."
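As a quick illustration (a toy sketch; the setup and names are mine, not from any particular reference), we can estimate the bias and variance terms empirically by resampling training sets, refitting a model on each, and measuring the spread of its predictions at fixed test inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.3                                   # noise level -> irreducible error sigma^2
f_true = lambda x: np.sin(2 * np.pi * x)
x_test = np.linspace(0, 1, 50)

def fit_and_predict(degree, n_train=30):
    """Sample a fresh training set, fit a polynomial of the given degree, predict on x_test."""
    x = rng.uniform(0, 1, n_train)
    y = f_true(x) + sigma * rng.normal(size=n_train)
    return np.polyval(np.polyfit(x, y, degree), x_test)

for degree in [1, 3, 9]:
    preds = np.stack([fit_and_predict(degree) for _ in range(200)])   # 200 resampled datasets
    bias2 = np.mean((preds.mean(axis=0) - f_true(x_test)) ** 2)       # squared bias
    variance = np.mean(preds.var(axis=0))                             # variance across datasets
    print(f"degree={degree}: bias^2={bias2:.3f}, variance={variance:.3f}, "
          f"expected generalization error ~ {bias2 + variance + sigma**2:.3f}")
```

In this setup, we'd expect the low-degree fit to be dominated by bias and the high-degree fit by variance.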


Generalization is about simplicity

A common assumption across the statistical learning theory literature is that simpler functions generalize better. It's worth spending a second to understand why we should expect this assumption to hold.

Occam's Razor

"Entities must not be multiplied beyond necessity".

Occam's razor is the principle of using the simplest explanation that fits the available data. If you're reading this, you probably already take it for granted.

Solomonoff Induction

Algorithmic complexity theory lets us define a probability distribution over (finite binary) strings,

m(x)= \sum_{p\,:\,U(p)=x}2^{-|p|} \approx 2^{-K(x)}, where the sum runs over the programs p that output x on some universal Turing machine U, |p| denotes a program's length, and K(\cdot) is the (prefix) Kolmogorov complexity. The probability of any given string is the probability that a uniformly random input tape fed to U outputs that string. Unfortunately, the halting problem makes this just a tad uncomputable. Still, in principle, this allows a hypercomputer to reason from the ground up what the next entry in a given sequence should be.

Bayesian/Akaike Information Criterion

One of the main strengths of the Bayesian frame is that it lets you enforce a prior \varphi(w) over the weights, which you can integrate out to derive a parameter-free model:

p(yx)=wp(yx,w)φ(w) dw.p(y|x) = \int_w p(y|x, w)\varphi(w)\ \text{d}w.

One of the main weaknesses is that this integral is almost always intractable. So Bayesians make a concession to the frequentists with a much more tractable Laplace approximation (i.e., you approximate your model as quadratic/gaussian in the vicinity of the maximum likelihood estimator (MLE), w^{(0)}):

K(w)12(ww(0))TI(w(0))(ww(0)), K(w) \approx \frac{1}{2}(w-w^{(0)})^T I(w^{(0)}) (w-w^{(0)}),

where I(w)I(w) is the Fisher information matrix:

I_{j,k}(w)=\left\langle\left(\frac{\partial}{\partial w_j} \log p(y|x, w)\right)\left(\frac{\partial}{\partial w_k} \log p(y |x, w)\right) \right\rangle_{X,Y|w}.

From this approximation, a bit more math gives us the Bayesian information criterion (BIC):

BIC=Ln(w0)+D2logn.\text{BIC} = L_n(w_0) + \frac{D}{2}\log n.

The BIC (like the related Akaike information criterion) is a criterion for model selection that penalizes complexity. Given two models, the one with the lower BIC tends to overfit less (/"generalize better").

That is, simpler models (with fewer parameters) are more likely in approximate Bayesian inference. We'll see in the section on singular learning theory that the BIC is unsuitable for deep neural networks, but a generalized version, the Widely Applicable Bayesian Information Criterion (WBIC) will pick up the slack.

Minimum Description Length

The BIC is formally equivalent to MDL. TODO

Maximum Entropy Modeling

TODO

Classic Learning Theory

As we've seen in the previous section, the question of how to generalize reduces to the question of how to find simple models that match the data.

In classical learning theory, the answer to this is easy: use fewer parameters, include a regularization term, and enforce early stopping. Unfortunately, none of these straightforwardly help us out with deep neural networks.

For one, double descent tells us that the relation between parameter count and generalization is non-monotonic: past a certain number of parameters, generalization error will start decreasing. Moreover, a large number of the complexity measures we've seen are extensive (they scale with model size). That suggests they're not suited to the task.

As for regularization, it gets more complicated. That regularization on polynomial regression leads to simpler functions is straightforward. That regularization on deep neural networks leads to simpler models is less obvious. Sure, a sparse l1 regularizer that pushes weights to zero seems like it would select for simpler models. But a non-sparse l2 regularizer? Linking small parameters to simple functions will require more work. A deeper problem is that DNNs can memorize randomly labeled data even with regularizers [5]. Why, then, do DNNs behave so differently on correctly labeled (/well-structured) data?

Other forms of regularization like dropout are more understandable: they explicitly select for redundancy, which collapses the effective parameter count.

Finally, early stopping doesn't square with the observation of grokking: in certain models, training and test loss may plateau (at a poor generalization gap) for many epochs before the test loss suddenly and sharply decreases (towards a much better generalization gap).

SGD Favors Flat Minima

  • Pβ(f)P_\beta(f) is the probability that MM expresses ff on DD upon a randomly sampled parametrization. This is our "prior"; it's what our network expresses on initialization.
  • Vβ(f)V_\beta(f) is a volume with Gaussian measure that equals Pβ(f)P_\beta(f) under Gaussian sampling of network parameters.
    • This is a bit confusing. We're not talking about a continuous region of parameter space, but a bunch of variously distributed points and lower-dimensional manifolds. Mingard never explicitly points out why we expect a contiguous volume. That or maybe it's not necessary for it to be contiguous
  • β\beta denotes the "Bayesian prior"
  • Popt(fS)P_\text{opt}(f|S) is the probability of finding ff on EE under a stochastic optimizer like SGD trained to 100% accuracy on SS.
  • Pβ(fS)=P(Sf)Pβ(f)Pβ(S),P_\beta(f|S) = \frac{P(S|f) P_\beta(f)}{P_\beta(S)}, is the probability of finding ff on EE upon randomly sampling parameters from i.i.d. Gaussians to get 100% accuracy on SS.
    • This is what Mingard et al. call "Bayesian inference"
    • P(S|f) = 1 if f is consistent with S and 0 otherwise.

Flat minima (i.e., minima whose Hessians have small eigenvalues) seem to be linked with generalization:

  • Hochreiter and Schmidhuber, 1997a; Keskar et al., 2016; Jastrzebski et al., 2018; Wu et al., 2017; Zhang et al., 2018; Wei and Schwab, 2019; Dinh et al., 2017.

SGD

Problems:

  • Suitable reparametrizations can change flatness without changing computation.

Initialization Favors "simple" functions

Mingard 2021

  • Levin et al. tell us that many real-world maps satisfy P(f) \lesssim 2^{-b\tilde K(f)+a}, where \tilde K is a computable approximation of the true Kolmogorov complexity K(f).
  • Empirically, maps of the form f:{0,1}n{0,1}f:\{0, 1\}^n \to \{0, 1\} satisfy a similar upper bound using a computable complexity measure (CSR).
  • This is an upper bound, not more than that!
  • The initialization acts as our prior in Bayesian inference.

SGD Performs Hidden Regularization

Mingard 2020

  1. Bayesian inference preserves the "simplicity" of the prior.
  2. SGD performs a kind of "Bayesian inference"
    • You can approximate Pβ(f)P_\beta(f) with Gaussian Processes.
    • Mingard's main result is that Popt(fS)Pβ(fS)P_\text{opt}(f|S) \approx P_\beta(f|S) appears to hold for many datasets (MNIST, Fashion-MNIST, IMDb movie review, ionosphere), architectures (Fully connected, Convolutional, LSTM),  optimizers (SGD, Adam, etc.), training schemes (including overtraining) and optimizer hyperparameters (e.g. batch size, learning rate).
    • Optimizer hyperparameters matter much less than Pβ(fS)P_\beta(f|S)


Singular Learning Theory

Glossary

  • ff^* is the true function
  • ff is some implemented function
  • MM is our neural network
  • DD is a dataset of pairs {(xi,yi)}i=1N\{(x_i, y_i)\}_{i=1}^N
  • S,EDS, E \subset D are the training & test sets, respectively.

Learnability

  • PAC in original formulation: simpler functions are easier to learn (polynomial time)

Usually, the true relation is probabilistic. In this case, we're not interested in a deterministic mapping ff^* from elements of X\mathcal X to elements Y\mathcal Y, but a probability distribution, q(x,y)q(x, y), which relates random variables XX and YY (where X\mathcal X and Y\mathcal Y are the associated sample spaces).

We don't have direct access to q(x, y), but we do have access to a dataset of samples, \mathbf D = \{(X_i, Y_i)\}_{i=1}^n, which is itself a random variable. We specify some loss function, \ell: \mathcal Y \times \mathcal Y \to \mathbb R, which maps a prediction, \hat y = f_w(x), and a true value, y = f^*(x), to a "loss".


Formally, we're trying to find the weights, w^\widehat w, that minimize the expected risk ("generalization error") for some choice of loss function, :Y×YR\ell: \mathcal Y \times \mathcal Y \to \mathbb R,

w^:=argminwWEX,Y[(fw(X),Y)]. \widehat w := \underset{w \in \mathcal W}{\operatorname{argmin}} \mathbb E_{X,Y}[ \ell(f_w(X), Y)].

We don't have direct access to the probability distribution, P(X,Y)P(X, Y), generating our data, so we approximate the expectation with an empirical average over a dataset D={(Xi,Yi)}i=1n\mathbf D = \{(X_i, Y_i)\} _ {i=1}^n to get the empirical risk or "test loss". 2 More accurately, we optimize w^\widehat w on a training set, Dtrain\mathbf D_\text{train}, but report performance on a test set, Dtest\mathbf D_\text{test}, to avoid sampling bias and overfitting, where D=DtrainDtest\mathbf D = \mathbf D_\mathrm{train} \cup \mathbf D_\mathrm{test} and DtrainDtest=\mathbf D_\mathrm{train} \cap \mathbf D_\mathrm{test} = \emptyset. 3

Footnotes

  1. Equivalently, we could separately model the likelihood p(xy)p(x|y) and prior p(y)p(y), then multiply to get the joint distribution. P.S. I prefer using qq and pp to mean the model and truth, respectively, but I'm keeping to the notation of Watanabe in Algebraic Geometry and Singular Learning Theory.

  2. Traditionally, statistical learning theory drops the explicit dependence on w, and instead looks at model selection at the level of f \in \mathcal F, q \in \mathcal Q. As we're interested in neural networks, we'll find it useful to fix a particular functional form of our model and look at selection of suitable parameters w. Later, we'll see that understanding the mapping w \to f_w is at the heart of understanding why deep neural networks work as well as they do.

  3. From Guedj [1]: "The past few decades have thus seen an increasing gap between the Bayesian statistical literature, and the machine learning community embracing the Bayesian paradigm – for which the Bayesian probabilistic model was too much of a constraint and had to be toned down in its influence over the learning mechanism. This movement gave rise to a series of works which laid down the extensions of Bayesian learning[.]"

  4. TODO: Something something Safe Bayes

Scaling Laws of Path dependence

NOTE: Work in Progress

How much do the models we train depend on the path they follow through weight space?

Should we expect to always get the same models for a given choice of hyperparameters and dataset? Or do the outcomes depend highly on quirks of the training process, such as the weight initialization and batch schedule?

If models are highly path-dependent, it could make alignment harder: we'd have to keep a closer eye on our models during training for chance forays into deception. Or it could make alignment easier: if alignment is unlikely by default, then increasing the variance in outcomes increases our odds of success.

Vivek Hebbar and Evan Hubinger have already explored the implications of path-dependence for alignment here. In this post, I'd like to formalize "path-dependence" in the language of dynamical systems, and share the results of experiments inspired by this framing.

From these experiments, we find novel scaling laws relating choices of hyperparameters like model size, learning rate, and momentum to the trajectories of these models during training. We find functional forms that provide a good empirical fit to these trajectories across a wide range of hyperparameter choices, datasets, and architectures.

Finally, we study a set of toy models that match the observed scaling laws, which gives us insight into the possible mechanistic origins of these trends, and a direction towards building a theory of the dynamics of training.

Formalizing Path-dependence

Hebbar and Hubinger define path-dependence "as the sensitivity of a model's behavior to the details of the training process and training dynamics." Let's make that definition more precise. To start, let us formalize the learning process as a dynamical system over weight space.

In the context of machine learning, our goal is (usually) to model some "true" function f:XYf^* : \mathcal X \to \mathcal Y1. To do this, we construct a neural network, f:X×WYf: \mathcal X \times \mathcal W \to \mathcal Y, which induces a model of ff^* upon fixing a choice of weights, wWRDw \in \mathcal W \subset \mathbb R^D. For convenience, we'll denote the resulting model with a subscript, i.e., Ffw(x):=f(x,w)\mathcal F \ni f_{w}(x):= f(x, w).2

To find the weights, woptw_\mathrm{opt}, such that fopt:=fwoptf_\text{opt} := f_{w_\text{opt}} is as "close" as possible to ff^*, we specify a loss function LnL_n, which maps a choice of parameters and a set of training examples, dtrain={(xi,yi)}i=1nZn\mathbf d_\text{train} = \{(x_{i},y_i)\}_{i=1}^n \in \mathcal Z^n, to a scalar "loss"3:

Ln:W×ZnR.L_n: \mathcal W \times \mathcal Z^n \to \mathbb R.

For a fixed choice of dataset, we denote the empirical loss, Ltrain(w)=L(w,dtrain)L_\text{train}(w) =L(w, \mathbf d_\text{train}). Analogously, we can define a test loss, LtestL_\text{test} over a corresponding test set, and batch loss, Lb(t)(w)L^{(t)}_\mathbf{b}(w), for a batch b(t)dtrain\mathbf b^{(t)} \subset \mathbf d_\text{train}.

Next, we choose an optimizer, Φ\Phi, which iteratively updates the weights (according to some variant of SGD),

Φ:W×B×HW(w(t), b(t), h)w(t+1). \begin{align} \Phi:\mathcal W \times \mathcal B\times \mathcal H &\to \mathcal W\\ (w^{(t)},\ b^{(t)},\ h) &\mapsto w^{(t+1)}. \end{align}

The optimizer depends on some hyperparameters, hHh \in \mathcal H, and learning schedule b={b(t)}t=1T\mathbf b = \{\mathbf b^{(t)}\}_{t=1}^T of TT batches, b(t)B\mathbf b^{(t)} \in \boldsymbol{\mathcal B}. Note: this learning schedule covers all NepochsN_\text{epochs} epochs. In other words, if an epoch consists of TepochsT_\text{epochs} batches, then T/Tepochs=NepochsT/ T_\text{epochs} = N_\text{epochs}. Within any given epoch, the batches are disjoint, but across epochs, repetition is allowed (though typically no batch will be repeated sample for sample).

If we take the learning schedule and hyperparameters as constant, we obtain a discrete-time dynamical system over parameter space, which we denote Φb,h:WW\Phi_{\mathbf b, h}: \mathcal W \to \mathcal W.
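As a minimal sketch (my own toy stand-in: a linear least-squares model in place of a real network, and only the learning rate as a hyperparameter), \Phi is just the body of the inner training loop:

```python
import numpy as np

def batch_loss_grad(w, batch):
    """Gradient of the batch loss for a linear model with squared error (toy stand-in)."""
    X, y = batch
    return 2 * X.T @ (X @ w - y) / len(y)

def Phi(w, batch, h):
    """One step of the dynamical system over weight space: (w_t, b_t, h) -> w_{t+1}."""
    return w - h["lr"] * batch_loss_grad(w, batch)

# Iterating Phi over a fixed batch schedule b = (b_1, ..., b_T) traces out a trajectory in W.
rng = np.random.default_rng(0)
X, y = rng.normal(size=(256, 5)), rng.normal(size=256)
schedule = [(X[i:i + 32], y[i:i + 32]) for i in range(0, 256, 32)]   # one epoch of batches
w = rng.normal(size=5)
h = {"lr": 1e-2}
for batch in schedule:
    w = Phi(w, batch, h)
```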

Two Kinds of Path Dependence

Given this dynamical system, Φb,h\Phi_{\mathbf b, h} over weight space, let us contrast two kinds of path dependence:

  1. Global path dependence. Given some distribution over starting weights, p(w(0))p(w^{(0)}), hyperparameters, p(h)p(h), or batches, p(b)p(\mathbf b), what is the resulting distribution over final weights, p(w(T))p(w^{(T)})?
  2. Local path dependence. For some choice of initial weights, w0w^{0}, hyperparameters, hh, or batches, b\mathbf b, and a small perturbation (of size ϵ1\epsilon \ll 1) to one of these values, how "different" does the perturbed model end up being from the baseline unperturbed model? In the limit of infinitesimal perturbations, this reduces to computing derivatives of the kind w(T)w(0)\frac{\partial w^{(T)}}{\partial{w^{(0)}}}, w(T)h\frac{\partial w^{(T)}}{\partial{h}}, and w(T)b(t)\frac{\partial w^{(T)}}{\partial{b^{(t)}}}.

Though we ultimately want to form global statements of path dependence, the distributions involved are generally intractable (which requires us to resort to empirical approximations). In the local view, we rarely care about specific perturbations, so we end up studying the relation between a similarly intractable distribution over perturbations, p(\delta), and p(w^{(T)}). Still, this is much easier to explore empirically because of its smaller support.

Continuous vs. Discrete Hyperparameters

Below, is a table of the various hyperparameters we encounter in our model, weight initializer, optimizer, and dataloader. Here, we're taking "hyperparameter" in the broadest sense to incorporate even the dataset and choice of model architecture.

Hyperparameters(h)Model (hm)Weight Initializer (hinit.)Optimizer (hΦ)Dataloader (hdl)“type"(FC, ResNet, GPT-N, etc.)“type" (KH normal/uniform, etc.)“type" (SGD, Adam, etc.)d,dtrain,dtestFC:{nhidden(l)}l=1nlayers  seeds (wref, δ)β1,β2,λ,ϵ,shuffle seedResNet:nlayers=18,34,w(0), ϵηTbatch\begin{array}{c c c c} &\text{Hyperparameters}&(\mathbf h)& \\ \hline \text{Model } (\mathbf{h}_{m}) & \text{Weight Initializer }(\mathbf{h}_\text{init.}) & \text{Optimizer }(\mathbf h_\Phi) &\text{Dataloader }(\mathbf h_\text{dl}) \\ \hline \text{``type"\tiny{(FC, ResNet, GPT-N, etc.)}} & \text{``type" \tiny{(KH normal/uniform, etc.)}}& \text{``type" \small{(SGD, Adam, etc.)}} & \mathbf{d}, \mathbf{d}_\text{train}, \mathbf{d}_\text{test} \\ \hookrightarrow \small{\text{FC}:\{n^{(l)}_\text{hidden}\}_{l=1}^{n_\text{layers}}}\ \ \, & \text{seeds } \small{(w_\text{ref},\ \delta)} & \hookrightarrow \small{\beta_{1},\beta_{2}, \lambda, \epsilon, \dots} & \text{shuffle seed} \\ \hookrightarrow \small{\text{ResNet}:\tiny{n_\text{layers}=18, 34, \dots}} & |w^{(0)}|,\ \epsilon & \eta & T_\text{batch} \end{array}

To study local perturbations, we're interested in those hyperparameters that we can vary continuously. E.g.: we can gradually increase the magnitude ϵ\epsilon of the perturbation applied to our baseline model, as well as the momenta, regularization coefficients, and learning rates of the optimizer, but we can't smoothly vary the number of layers in our model or the type of optimizer we're using.

From weights to functions

The problem with studying dynamics over W\mathcal W is that we don't actually care about wWw \in \mathcal W. Instead, we care about the evolution in function space, F\mathcal F, which results from the mapping m:WwfwFm: \mathcal W \ni w \mapsto f_{w}\in \mathcal F.

We care about this evolution because the mapping to function space is non-injective; the internal symmetries of mm ensure that different choices of wWw\in \mathcal W can map to the same function fFf \in \mathcal F.[1] As a result, distance in W\mathcal W can be misleading as to the similarity of the resulting function.

The difficulty with function space is that it is infinite-dimensional, so studying it is that much more intractable than studying weight space. In practice, then, we have to estimate where we are in function space over a finite number of samples of input-output pairs, for which we typically use the training and test sets.

However, even though we have full knowledge of W\mathcal W, we can't resolve exactly where we are in F\mathcal F from any finite set of samples. At best, we can resolve the level set in F\mathcal F with a certain empirical performance (like training/test loss). This means our sampling procedure acts as a second kind of non-injective map from FFobs=(X×Y)n\mathcal F\to \mathcal F_\text{obs} = (\mathcal X \times \mathcal Y)^n.

A note on optimizer state: If we want to be exhaustive, let us recall that optimizers might have some internal state, s \in \mathcal S. E.g., for Adam, s_t = (m_t, v_t), running averages of the first and second moments of the gradients, respectively. Then, the dynamical system over \mathcal W is more appropriately a dynamical system over the joint space \mathcal W \times \mathcal S. Completing this extension will have to wait for another day; we'll ignore it for now for the sake of simplicity and because, in any case, s_t tends to be a deterministic function of \mathbf b and h.

\begin{array}{c|c} \multicolumn{2}{c}{\text{Metrics}} \\ \hline \mathcal W & \mathcal F \\ \hline d^{(p)}_{\mathcal W}(w, w')= \|w-w'\|_{p} & L_{\text{cf.}}(f_{w}, f_{w'}) = \frac{1}{N}\sum\limits_{i=1}^{N} \ell\left(f_{w}(x_i), f_{w'}(x_{i})\right) \\ S_{C}(w, w') = \frac{w \cdot w'}{|w||w'|} & \end{array}
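Here's a small sketch of these metrics (function names are mine; loss_fn stands in for \ell): the weight-space p-norm and cosine similarity, plus the functional comparison loss L_cf. that scores one model's predictions against another's:

```python
import numpy as np

def d_w(w1, w2, p=2):
    """p-norm distance between flattened weight vectors."""
    return np.linalg.norm(w1 - w2, ord=p)

def cosine_similarity(w1, w2):
    return (w1 @ w2) / (np.linalg.norm(w1) * np.linalg.norm(w2))

def l_cf(f1, f2, xs, loss_fn):
    """Functional comparison loss: average loss of f1's predictions against f2's."""
    return np.mean([loss_fn(f1(x), f2(x)) for x in xs])

# Example usage with two "models" that are simple callables:
f_baseline = lambda x: np.tanh(x)
f_perturbed = lambda x: np.tanh(1.01 * x)
xs = np.linspace(-3, 3, 100)
sq_loss = lambda a, b: (a - b) ** 2
print(l_cf(f_baseline, f_perturbed, xs, sq_loss))
```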

Experiments

In section XX, we defined path dependence in terms of the relation between (P(B_t), P(W_0), P(H)) and (P(W_T), P(F_T)). As mentioned, there are two major challenges:

  1. In general, we have to estimate these distributions empirically, and training neural networks is already computationally expensive, so we're restricted to studying smaller networks and datasets.
  2. The mapping m: \mathcal W \to \mathcal F is not straightforward because of the symmetries of m.

To make it easier for ourselves, let us restrict our attention to the narrower case of studying local path-dependence: we apply a (small) perturbation, \epsilon, to one of (b_t, w_0, h). E.g., we'll study probability densities of the kind p(\mathbf w_0)=\mathcal N(\mathbf w_{0}\,|\,\mathbf w_\text{baseline}, \epsilon \mathbf 1). For discrete variables (like the number of layers, network width, and batch schedule), there's a minimum size we can make \epsilon.

Within dynamical systems theory, there are two main ways to view dynamical systems:

  1. The trajectory view studies the evolution of points like (w0,w1,,wT)(w_0, w_1, \dots, w_T) and (f1,f2,,fT)(f_1, f_2, \dots, f_T).
  2. The probability view studies the evolution of densities like (p0(w),p1(w),,pT(w))(p_0(w), p_1(w), \dots, p_T(w)) and (p1(f),p2(f),,pT(f))(p_1(f), p_2(f), \dots, p_T(f)).

Both have something to tell us about path dependence:

  1. In the trajectory view, sensitivity of outcomes becomes a question of calculating Lyapunov exponents, the rate at which nearby trajectories diverge/converge.
  2. In the probability view, sensitivity of outcomes becomes a question of measuring autocorrelations, wiwi+τ\langle w_{i} w_{i+\tau}\rangle and fifi+τ\langle f_{i} f_{i+\tau} \rangle.

Tracking Trajectories

The experimental set-up involves taking some baseline model, m0m_0, then applying a Gaussian perturbation to the initial weights with norm ϵ\epsilon, to obtain a set of perturbed models {mi}i=1nmodels\{m_i\} _{i=1}^{n _\text{models}}. We train these models on MNIST (using identical batch schedules for each model).

Over the course of training, we track how these models diverge via several metrics (see next subsection). For each of these, we study how the rate of divergence varies for different choices of hyperparameters: momentum β\beta, weight decay λ\lambda, learning rate η\eta, hidden layer width (for one-hidden-layer models), and number of hidden layers. Moving on from vanilla SGD, we compare adaptive variants like Adam and RMSProp.

TODO: Non-FC models. Actually calculate the rate of divergence (for the short exponential period at the start). Perform these experiments for several different batch schedules. Perturbations in batch schedule. Other optimizers besides vanilla SGD.

Measuring distance

  • dw(p)(mi,m0)d_\mathbf{w}^{(p)}(m_i, m_0): the pp-norm between the weight vectors of each perturbed model and the baseline model. We'll ignore the pp throughout and restrict to the case p=2p=2. This is the most straightforward notion of distance, but it's flawed in that it can be a weak proxy for distance in F\mathcal F because of the internal symmetries of the model.
  • dw(mi,m0)w0\frac{d_{\mathbf w}(m_{i}, m_0)}{|\mathbf w_{0}|}: the relative pp-norm between weight vectors of each perturbed model.
  • Ltrain,test(mi)L_\text{train,test}(m_i): the training/test losses
  • δLtrain, test(mi,m0)\delta L_\text{train, test}(m_{i}, m_0): the training/test loss relative (difference) to the baseline model m0m_0
  • ftrain, test(mi)f_\text{train, test}(m_i), δftrain, test(mi,m0)\delta f_\text{train, test}(m_i, m_0): the training/test set classification accuracy and relative classification accuracy.
  • Ltrain cf., test cf.L_\text{train cf., test cf.} and ftrain cf., test cf.f_\text{train cf., test cf.}: same as above, but we take the predictions of the baseline model as the ground truth (rather than the actual labels in the test set).

A few more metrics to include:

  • dwperm.d^\text{perm.}_{\mathbf w}: the l2 norm after adjusting for permutation differences (as described in Ainsworth et al.)
  • w0wiperm.Ltrain (cf.), test (cf.)(w)dw\int_{\mathbf w_0}^{\mathbf w_{i}^\text{perm.}} L_\text{train (cf.), test (cf.)}(\mathbf w) d\mathbf w. The loss integrated over the linear interpolation between two models' weights after correcting for permutations (as described in Ainsworth et al.)

Tracking Densities

TODO

Results

One-hidden-layer Models

What's remarkable about one-hidden-layer models is how little the model depends on weight initialization: almost all of the variance seems to be explained by the batch schedule. Even for initial perturbations of size ϵ=10\epsilon=101, the models appear to become almost entirely equivalent in function space across the entire training process. In the figure below, you can see that the differences in LtrainL_\text{train} are imperceptible (both for a fixed ϵ\epsilon and across averages over different ϵ\epsilon).

TODO: I haven't checked ϵ/w\epsilon/|\mathbf w|. I'm assuming this is >1 for ϵ=10\epsilon=10 (which is why I find this surprising), but I might be confused about PyTorch's default weight initialization scheme. So take this all with some caution.



The rate of growth for d_\mathbf{w} has a very short exponential period (shorter than a single epoch), followed by a long linear period (up to 25 epochs, long past when the error has stabilized). In some cases, you'll see a slight bend upwards or downwards. I need to test over more perturbations (right now it's 10 perturbed models) and longer training runs to clean this up; these trends may well change over longer periods.


Hyperparameters

  • Momentum: (\beta=0.1, 0.5, 0.9) The d_\mathbf w curves look slightly more curved upwards/exponential, but maybe that's confirmation bias (I expected momentum to make divergence more exponential back when I still expected exponential divergence). Small amounts of momentum appear to increase convergence (relative to other models equivalent up to momentum), while large amounts increase divergence. For very small amounts, the effect appears to be negligible. Prediction: the d_w curves curve down because in flat basins, momentum slows you down.
  • Learning rate: Same for η=103,102,101\eta=10^{-3}, 10^{-2}, 10^{-1}. (Not too surprising) TODO: Compare directly between different learning rates. Can we see η\eta back in the slope of divergence? Prediction: This should not make much of a difference. I expect that correcting for the learning rate should give nearly identical separation slopes.
  • Weight Decay: TODO for \lambda=10^{-3}, 10^{-2}, 10^{-1}. Prediction: d_w will shrink (because |w| will shrink). The normalized d_w/|w| will curve downwards because we break the ReLU-scaling symmetry, which means the volume of our basins will shrink.
  • Width: TODO. Both this and depth require care. With different sizes of \mathbf w, we can't naively compare norms. Is it enough to divide by \dim \mathbf w? Prediction: I expect the slope of divergence to increase linearly with the width of a one-layer network (from modeling each connection as separating linearly of its own accord).
  • Depth: TODO. Prediction: I expect the shape of dwd_w for each individual layer to become more and more exponential with depth (because the change will be the product of layers that diverge linearly independently). I expect this to dominate the overall divergence of the networks.
  • Architecture: Prediction: I expect convolutional architectures to decrease the rate of separation (after correcting for the number of parameters).
  • Optimizer:
    • Prediction: I expect adaptive techniques to make the rate of divergence exponential.

Datasets

  • Computer Vision
    • MNIST
    • Fashion-MNIST
    • Imagenet: Prediction: I expect Imagenet to lead to the same observations as MNIST (after correcting for model parameter count).
  • Natural Language
    • IMDb movie review database

Why Linear?

I'm not sure. My hunch going in was that I'd see either exponential divergence (indicating chaos) or square root divergence (i.e., Brownian noise). Linear surprises me, and I don't yet know what to make of it.

Maybe all of this is just a fact about one-hidden-layer networks. If each hidden layer evolves independently linearly, maybe this combines additively into something Brownian. Or maybe it's multiplicative (so exponential).

Deep Models

TODO

Appendix

Weight Initialization

For ReLU-based networks, we use a modified version of Kaiming (normal) initialization, which is based on the intuition of Kaiming initialization as sampling weights from a hyperspherical shell of vanishing thickness.

Kaiming initialization is sampling from a hyperspherical shell

Consider a matrix, \mathbf w^{(l)}, representing the weights of a particular layer l with shape (D_\mathrm{in}^{(l)}, D_\mathrm{out}^{(l+1)}). D_\mathrm{in}^{(l)} is also called the fan-in of the layer, and D_\mathrm{out}^{(l+1)} the fan-out. For ease of presentation, we'll ignore the bias, though the following reasoning applies equally well to the bias.

We're interested in the vectorized form of this matrix, w(l)RD(l)\vec w^{(l)} \in \mathbb R^{D^{(l)}}, where D(l)=Din(l)×Dout(l+1)D^{(l)} =D_\mathrm{in}^{(l)} \times D_\mathrm{out}^{(l+1)}.

In Kaiming initialization, we sample the components, wi(l)w_i^{(l)}, of this vector, i.i.d. from a normal distribution with mean 0 and variance σ2\sigma^2 (where σ2=2Din(l)\sigma^2 = \frac{2}{D_\mathrm{in}^{(l)}}).

Geometrically, this is equivalent to sampling from a hyperspherical shell, SD1S^{D-1} with radius Dσ\sqrt{D}\sigma and (fuzzy) thickness, δ\delta.

This follows from some straightforward algebra (dropping the superscript ll for simplicity):

E[w2]=E[i=1Dwi2]=i=1DE[wi2]=i=1Dσ2=Dσ2,\mathbb E[|\mathbf w|^2] = \mathbb E\left[\sum_{i=1}^D w_i^2\right] = \sum_{i=1}^D \mathbb E[w_i^2] = \sum_{i=1}^D \sigma^2 = D\sigma^2,

and

δ2var[w2]=E[(i=1Dwi2)2]E[i=1Dwi2]2=i,j=1DE[wi2wj2](Dσ2)2=ijDE[wi2]E[wj2]+i=1DE[wi4](Dσ2)2=D(D1)σ4+D(3σ4)(Dσ2)2=2Dσ4.\begin{align} \delta^2 \propto \mathrm{var} [|\mathbf w|^2] &= \mathbb E\left[\left(\sum_{i=1}^D w_i^2\right)^2\right] - \mathbb E\left[\sum_{i=1}^D w_i^2\right]^2 \\ &= \sum_{i, j=1}^D \mathbb E[w_i^2 w_j^2] - (D\sigma^2)^2 \\ &= \sum_{i \neq j}^D \mathbb E[w_i^2] \mathbb E[w_j^2] + \sum_{i=1}^D \mathbb E[w_i^4]- (D\sigma^2)^2 \\ &= D(D-1) \sigma^4 + D(3\sigma^4) - (D\sigma^2)^2 \\ &= 2D\sigma^4. \end{align}

So the thickness as a fraction of the radius is

δDσ=2DσD=2σ=2Din(l),\frac{\delta}{\sqrt{D}\sigma} = \frac{\sqrt{2D}\sigma}{\sqrt{D}} = \sqrt{2}\sigma = \frac{2}{\sqrt{D_\mathrm{in}^{(l)}}},

where the last equality follows from the choice of σ\sigma for Kaiming initialization.

This means that for suitably wide networks (Din(l)D_\mathrm{in}^{(l)} \to \infty), the thickness of this shell goes to 00.

Taking the thickness to 0

This suggests an alternative initialization strategy: sample directly from the boundary of a hypersphere with radius Dσ\sqrt{D}\sigma, i.e., modify the shell thickness to be 00.

This can easily be done by sampling each component from a normal distribution with mean 0 and variance 11 and then normalizing the resulting vector to have length Dσ\sqrt{D}\sigma (this is known as the Muller method).
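In code, the modified initializer might look like this (a sketch, with my own function name): draw a standard Gaussian vector for the direction and rescale it to lie exactly on the sphere of radius \sqrt{D}\sigma:

```python
import numpy as np

def sample_hypersphere_init(fan_in, fan_out, rng):
    """Kaiming-style init constrained to the sphere of radius sqrt(D) * sigma."""
    D = fan_in * fan_out
    sigma = np.sqrt(2.0 / fan_in)                 # Kaiming (He) variance for ReLU
    v = rng.normal(size=D)                        # direction ~ uniform on the sphere
    v *= np.sqrt(D) * sigma / np.linalg.norm(v)   # rescale to the target radius
    return v.reshape(fan_in, fan_out)

rng = np.random.default_rng(0)
w = sample_hypersphere_init(fan_in=784, fan_out=128, rng=rng)
print(np.linalg.norm(w))   # exactly sqrt(D) * sigma, rather than approximately
```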

Perturbing Weight initialization

Naïvely, if we're interested in a perturbation analysis of the choice of weight initialization, we prepare some baseline initialization, w0\mathbf w_0, and then apply i.i.d. Gaussian noise, δ\boldsymbol \delta, to each of its elements, δiN(0,ϵ2)\delta_i \sim \mathcal N(0, \epsilon^2).

The problem with this is that the perturbed weights w=w0+δ\mathbf w = \mathbf w_0 + \boldsymbol\delta are no longer sampled from the same distribution as the baseline weights. In terms of the geometric picture from the previous section, we're increasing the thickness of the hyperspherical shell in the vicinity of the baseline weights.

There is nothing wrong with this per se, but it introduces a possible confounder (the thickness).

The modification we made to Kaiming initialization was to sample directly from the boundary of a hypersphere, rather than from a hyperspherical shell. This is a more natural choice when conducting a perturbation analysis, because it makes it easier to ensure that the perturbed weights are sampled from the same distribution as the baseline weights.

Geometrically, the intersection of the hypersphere of radius w_0 = |\mathbf w_0| (on which the baseline weights live) with a hypersphere of radius \epsilon centered at some point on the first sphere's boundary is a lower-dimensional hypersphere with a modified radius \epsilon'. If we sample uniformly from this intersection sphere, then the resulting points will follow the same distribution as uniform samples from the original hypersphere, restricted to the intersection.

This suggests a procedure to sample from the intersection of the weight initialization hypersphere and the perturbation hypersphere.

First, we sample from a hypersphere of radius \epsilon' that lives in the hyperplane spanned by the first D-1 coordinates (using the same technique we used to sample the baseline weights). From a bit of trigonometry (see figure below), we know that the radius of this intersection sphere is \epsilon' = w_0\sin \theta, where \theta = \cos^{-1}\left(1-\frac{\epsilon^2}{2w_0^2}\right) is the angle between the baseline and perturbed weight vectors.


Next, we rotate the vector so it is orthogonal to the baseline vector w0\mathbf w_0. This is done with a Householder reflection, HH, that maps the current normal vector n^=(0,,0,1)\hat{\mathbf n} = (0, \dots, 0, 1) onto w0\mathbf w_0:

H=I2ccTcTc,H = \mathbf I - 2\frac{\mathbf c \mathbf c^T}{\mathbf c^T \mathbf c},

where

c=n^+w^0,\mathbf c = \hat{\mathbf n} + \hat {\mathbf w}_0,

and w^0=w0w0\hat{\mathbf w}_0 = \frac{\mathbf w_0}{|w_0|} is the unit vector in the direction of the baseline weights.

Implementation note: For the sake of tractability, we directly apply the reflection via:

Hy=y2cTycTcc.H\mathbf y = \mathbf y - 2 \frac{\mathbf c^T \mathbf y}{\mathbf c^T\mathbf c} \mathbf c.

Finally, we translate the rotated intersection sphere along the baseline vector, so that its boundary goes through the intersection of the two hyperspheres. From the figure above, we find that the translation has the magnitude w0=w0cosθw_0' = w_0 \cos \theta.

By the uniform sampling of the intersection sphere and the uniform sampling of the baseline vector, we know that the resulting perturbed vector will have the same distribution as the baseline vector, when restricted to the intersection sphere.
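Putting the pieces together, here's a numpy sketch of the procedure (my own implementation; note that with c = \hat n + \hat w_0 the Householder reflection actually sends \hat n to -\hat w_0, which makes no difference here because the intersection sphere we sample from is symmetric about the w_0 axis):

```python
import numpy as np

def perturb_on_sphere(w0, eps, rng):
    """Sample w with |w| = |w0| and |w - w0| = eps, uniformly on the intersection sphere."""
    D = w0.size
    r = np.linalg.norm(w0)
    w0_hat = w0 / r
    theta = np.arccos(1 - eps**2 / (2 * r**2))     # angle between w and w0
    # Uniform sample from the sphere of radius |w0| sin(theta) in the first D-1 coordinates.
    u = rng.normal(size=D - 1)
    u *= r * np.sin(theta) / np.linalg.norm(u)
    p = np.concatenate([u, [0.0]])                 # orthogonal to n_hat = (0, ..., 0, 1)
    # Householder reflection with c = n_hat + w0_hat maps the hyperplane orthogonal to n_hat
    # onto the hyperplane orthogonal to w0_hat.
    n_hat = np.zeros(D)
    n_hat[-1] = 1.0
    c = n_hat + w0_hat
    p = p - 2 * (c @ p) / (c @ c) * c
    # Translate along w0_hat so the sample lands on the intersection of the two spheres.
    return p + r * np.cos(theta) * w0_hat

rng = np.random.default_rng(0)
w0 = rng.normal(size=1000)
w0 *= 10 / np.linalg.norm(w0)                      # baseline on the sphere of radius 10
w = perturb_on_sphere(w0, eps=0.1, rng=rng)
print(np.linalg.norm(w), np.linalg.norm(w - w0))   # ~10.0 and ~0.1
```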


Dynamical Systems

Formally, a dynamical system is a tuple (T,M,Φ)(\mathcal T, \mathcal M, \Phi), where T\mathcal T is the "time domain" (some monoid), M\mathcal M is the phase space over which the evolution takes place (some manifold), and Φ\Phi is the evolution function, a map,

Φ:T×MM,\Phi: \mathcal T \times \mathcal M \to \mathcal M,

that satisfies, xM,t1,t2T\forall x \in \mathcal M, \forall t_{1}, t_{2}\in \mathcal T,

Φ(0,x)=x,Φ(t2,Φ(t1,x))=Φ(t2+t1,x).\begin{align} \Phi(0, x) &= x,\\ \Phi(t_{2}, \Phi(t_{1}, x)) &= \Phi(t_{2}+t_{1}, x). \end{align}

4

Informally, we're usually interested in one of two perspectives:

  1. Trajectories of individual points in M\mathcal M, or
  2. Evolution of probability densities over M\mathcal M.

The former is described in terms of differential equations (for continuous systems) or difference equations (for discrete systems), i.e.,

xt=ϕ(x,t)\frac{\partial x}{\partial t} = \phi(x, t)

or

xt+1xt=ψ(xt,t)x_{t+1}- x_{t} = \psi(x_{t}, t)

The latter is described in terms of a transfer operator:

\partial_{t}\, p(x) = \mathcal L\, p(x).

Both treatments have their advantages: the particle view admits easier empirical analysis (we just simulate a trajectory), while the density view lets us make statements about entire distributions over outcomes.

In terms of the former, path-sensitivity becomes the question of chaos: do nearby trajectories converge or diverge as training proceeds?

Varying the Depth

For a fair comparison across network depths, we want to keep the total number of parameters DD constant as we increase depth.

Given nhn_h hidden layers with biases that each have width, ww, an input dimensionality of ninn_\mathrm{in}, and an output dimensionality of noutn_\mathrm{out}, the total number of parameters is

\begin{align} D(w, n_h) &= (n_{\text{in}}+ 1)\ w + n_{h}\ (w +1)\ w + (w + 1)\ n_{\text{out}} \\ &= n_{h}w^{2}+ (n_\text{in} + n_{h}+n_\text{out}+1)w + n_{\text{out}}. \end{align}
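A small helper along these lines (my own sketch, with arbitrary default input/output dimensions) solves the quadratic above for the width that hits a target parameter budget at a given depth:

```python
import numpy as np

def width_for_budget(D_target, n_hidden_layers, n_in=784, n_out=10):
    """Solve n_h*w^2 + (n_in + n_h + n_out + 1)*w + n_out = D_target for the width w."""
    a = n_hidden_layers
    b = n_in + n_hidden_layers + n_out + 1
    c = n_out - D_target
    w = (-b + np.sqrt(b**2 - 4 * a * c)) / (2 * a)
    return int(round(w))

for depth in [1, 2, 4, 8]:
    print(depth, width_for_budget(D_target=100_000, n_hidden_layers=depth))
```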

Lyapunov spectrum

The Lyapunov exponent λ\lambda quantifies the rate of separation of infinitesimally close trajectories:

δZ(t)eλtδZ0. |\delta \mathbf{Z}(t)| \approx e^{\lambda t}\left|\delta \mathbf{Z}_0\right|.

If the exponent is positive, then nearby trajectories will diverge, and the dynamics are chaotic. If the exponent is negative, then nearby trajectories will converge.

For a DD-dimensional system, there are DD Lyapunov exponents. We often focus exclusively on the maximal Lyapunov exponent because it dominates the overall rate of separation between two neighboring trajectories. However, the full spectrum contains valuable additional information like the rate of entropy production, the fractal dimension, the Hausdorff dimension, and the Lyapunov dimension.

The first major limitation is that the full Lyapunov spectrum is intractable for large D. The other major limitation is that the Lyapunov spectrum requires a suitable norm. We can use the l2 norm in \mathcal W, but, as we already saw, what we really care about is measuring distance in \mathcal F with some metric d : \mathcal F \times \mathcal F \to \mathbb R.

One option would be to repurpose the loss function \ell:

d(f1,f2)=1Ni=1N(f1(xi),f2(xi)), d(f_1, f_2) = \frac{1}{N} \sum_{i=1}^N \ell(f_1(x_i), f_2(x_i)),

Here, we treat f2(xi)f_2(x_i) as the truth, and use the empirical risk as our distance metric. We could define an analogous distance over either the train or test set. As long as \ell is a suitable metric (or easily converted into a metric through, e.g., symmetrization), dd will be too.

This suffers a major downside that two functions will be seen as identical as long as they have the same performance on the dataset in question. They can have totally different performance on unseen examples. Devising suitable distance metrics for neural networks requires more work.
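As a sketch of the trajectory-side computation (my own code, with random walks standing in for logged weight trajectories), we can estimate an effective divergence rate by fitting a line to the log of the weight-space distance over the early steps; a roughly constant slope is the signature of exponential separation:

```python
import numpy as np

def divergence_rate(ws_baseline, ws_perturbed, t_max=None):
    """Estimate a maximal-Lyapunov-style exponent from two weight trajectories.

    ws_* are arrays of shape (T, D): one flattened weight vector per training step.
    We fit log |delta w(t)| ~ lam * t + const over the first t_max steps.
    """
    deltas = np.linalg.norm(ws_baseline - ws_perturbed, axis=1)
    t = np.arange(len(deltas))[:t_max]
    lam, _ = np.polyfit(t, np.log(deltas[:t_max]), deg=1)
    return lam

# Usage (with random-walk stand-ins rather than real training runs):
rng = np.random.default_rng(0)
T, D = 500, 100
base = np.cumsum(rng.normal(size=(T, D)), axis=0)
pert = base + np.cumsum(rng.normal(scale=0.01, size=(T, D)), axis=0)
print(divergence_rate(base, pert, t_max=100))
```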


Autocorrelations

The autocorrelation of a random process {Xt}\{X_t\} is the correlation between two values of that process at different times,

RXX(t1,t2)=Xt1Xt2XX. R_{XX}(t_1, t_2) = \langle X_{t_1} X_{t_2}\rangle_{XX}.

For stationary processes (where Xt\langle X_t\rangle is independent of tt), this becomes a function of one variable, τ=t2t1\tau=t_2-t_1,

RXX(τ)=XtXt+τXX. R_{XX}(\tau) = \langle X_tX_{t+\tau}\rangle_{XX}.

In the case of training, the dynamics aren't stationary: learning rate schedules and the "descent" of gradient descent ensure that these correlations will depend on our choice of starting time. However, we're not typically interested in correlations between two later time steps; we care about autocorrelations relative to the starting point, t_1 = 0.

In practice, the probabilistic and point-wise views are intimately connected: the relevant autocorrelation timescales can often be directly related to the maximal Lyapunov exponent. These two formulas are only a small sample of the available tooling.

F\mathcal F is the space of pp-integrable functions, fwFf_w \in \mathcal F iff

Xfw(x)pdx<,\int_{\mathcal X} |f_{w}(x)|^{p}\, \mathrm dx < \infty,

which is equipped with a metric,

d_{\mathcal F}(f, g) = \left(\int_{\mathcal X}|f(x)-g(x)|^{p}\, \mathrm d x\right)^{1/p}.

So we exchange evolution over a finite-dimensional \mathcal W for evolution over an infinite-dimensional function space, \mathcal F.

Footnotes

  1. For regression, Y=RN\mathcal Y = \mathbb R^N, for classification, YN\mathcal Y \subset \mathbb N, for self-supervised tasks, we're often interested in Y=X\mathcal Y = \mathcal X, etc.

  2. To start, we'll ignore the model hyperparameters. You can view this as either absorbing the hyperparameters into the weights, into the optimizer, or into our choice of ff. Later, it'll be useful to separate out hyperparameters (i.e., f:X×Y×HYf: \mathcal X \times \mathcal Y \times \mathcal H \to \mathcal Y).

  3. A quick note on notation: Often, you'll see the dataset denoted with a capital DD. Here, we're using a lowercase because it'll be useful to treat d\mathbf d as an instance of a random variable D\mathbf D. Similarly for xx, yy, and ww (XX, YY, and WW). Even though xx, yy, and ww are (typically) all vectors, we'll reserve boldface for sets. TODO: Maybe just bite the bullet and bold all of them.

  4. It's possible to generalize this further so that Φ:U(T×M)M\Phi : U \subseteq (\mathcal T \times \mathcal M) \to \mathcal M, but this won't be necessary for us.

Singular Learning Theory

NOTE: Work in Progress

Abstract


Introduction

Regular learning theory is lying to you: "overparametrized" models actually aren't overparametrized, and generalization is not just a question of broad basins.


The standard explanation for generalization in neural networks is that gradient descent settles in flat basins of the loss function. On the left, in a sharp minimum, the updates bounce the model around. Performance will vary wildly with new examples. On the right, in a flat minimum, the updates settle to zero. Performance is stable under small perturbations.

That's because loss basins actually aren't basins but valleys, and at the base of these valleys lie manifolds of constant, minimum loss. The higher the dimension of these "rivers", the lower the effective dimensionality of your model. Generalization is a balance between expressivity (more effective parameters) and simplicity (fewer effective parameters).


Singular directions lower the effective dimensionality of your model. In this example, a line of degenerate points effectively restricts the two-dimensional loss surface to one dimension.

These manifolds correspond to the internal symmetries of NNs: continuous variations of a given network that perform the same calculation. Many of these symmetries are predetermined by the architecture and so are always present. We call these "generic". The more interesting symmetries are non-generic symmetries, which the model can form or break during training.

In this light, part of the power of NNs is that they can vary their effective dimensionality (and thus their expressivity). Generality comes from a kind of "forgetting" in which the model throws out unnecessary dimensions. At the risk of being elegance-sniped, SLT seems like a promising route to develop a better understanding of training dynamics (and phenomena such as sharp left turns and path-dependence). If we're lucky, SLT may even enable us to construct a grand unified theory of scaling.

A lot still needs to be done (esp. in terms of linking the Bayesian presentation of singular learning theory to conventional machine learning), but, from an initial survey, singular learning theory feels meatier than other explanations of generalization.1 So let me introduce you to the basics…

Singular Models

Maximum likelihood estimation is KL-divergence minimization.

We're aiming for a shallow introduction of questionable rigor. For full detail, I recommend Carroll's MSc thesis here (whose notation I am adopting).

The setting is Bayesian, so we'll start by translating the setup of "standard" regression problem to more appropriate Bayesian language.

We have some true distribution q(yx)q(y|x) and some model p(yx,w)p(y|x, w) parametrized by weights, wWRDw \in W \subseteq \mathbb R^D. Our aim is to learn the weights that make pp as "close" as possible to qq.

Given a dataset \mathcal D = \{(x_i, y_i)\}_{i=1}^n, frequentist learning is usually formulated in terms of the empirical likelihood of our data (which assumes that each sample is i.i.d.):

p(\mathbf y|\mathbf x, w) = \prod_{i=1}^{n} p(y_i|x_i,w).

The aim of learning is to find the weights that maximize this likelihood (hence "maximum likelihood estimator"):

w=argmaxwp(yx,w). w^* = \text{argmax}_w\, p(\mathbf y|\mathbf x, w).

That is: we want to find the weights which make our observations as likely as possible.

In practice, because sums are easier than products and because we like our bits to be positive, we end up trying to minimize the negative log likelihood instead of the vanilla likelihood. That is, we're minimizing average bits of information rather than maximizing probabilities:

Ln(w):=1nlogp(yx,w)=1ni=1nlogp(yixi,w). L_n(w) := -\frac{1}{n}\log p(\mathbf y | \mathbf x, w) = -\frac{1}{n}\sum_{i=1}^n\log p(y_i|x_i, w).

If we define the empirical entropy, SnS_n, of the true distribution,

Sn:=1ni=1nlogq(yixi), S_n := -\frac{1}{n}\sum_{i=1}^n \log q(y_i|x_i),

then, since SnS_n is independent of ww, we find that minimizing Ln(w)L_n(w) is equivalent to minimizing the empirical Kullback-Leibler divergence, Kn(w)K_n(w), between our model and the true distribution:

Kn(w):=1ni=1nlogq(yixi)p(yixi,w)=Ln(w)Sn. K_n(w):= \frac{1}{n}\sum_{i=1}^n\log \frac{q(y_i|x_i)}{p(y_i|x_i, w)} = L_n(w) - S_n.

So maximizing the likelihood is not just some half-assed frequentist heuristic. It's actually an attempt to minimize the most straightforward information-theoretic "distance" between the true distribution and our model.
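Here's a tiny numerical check of this decomposition (a sketch with a one-parameter Gaussian model of my own choosing):

```python
import numpy as np

rng = np.random.default_rng(0)
n, w_true, w_model = 1_000, 1.5, 0.8

x = rng.normal(size=n)
y = w_true * x + rng.normal(size=n)          # samples from the truth q(y|x)

log_gauss = lambda y, mu: -0.5 * (y - mu) ** 2 - 0.5 * np.log(2 * np.pi)

L_n = -np.mean(log_gauss(y, w_model * x))    # empirical NLL of the model
S_n = -np.mean(log_gauss(y, w_true * x))     # empirical entropy of the truth
K_n = np.mean(log_gauss(y, w_true * x) - log_gauss(y, w_model * x))

print(np.isclose(L_n, K_n + S_n))            # True: L_n(w) = K_n(w) + S_n
```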

In the limit of infinite data, the empirical K_n(w) converges to the KL divergence proper:

 K(w) := D_{\text{KL}}(q(y, x)\,||\,p(y, x | w)) = \mathbb E_{q(y, x)}\left[\log \frac{q(y|x)}{p(y|x, w)}\right].

The advantage of working with the KL-divergence is that it's bounded below: K(w) \geq 0, with equality iff p(y|x, w) = q(y|x) almost everywhere.

In this frame, our learning task is not simply to minimize the KL-divergence, but to find the true parameters:

W0:={wWK(w)=0}={wWp(yx,w)=q(yx)}. W_0 := \{w \in W|K(w)=0\} = \{w \in W|p(y|x, w) = q(y|x)\}.

Note that it is not necessarily the case that a set of true parameters actually exists. If your model is insufficiently expressive, then the true model need not be realizable: your best fit may have some non-zero KL-divergence.

Still, from the perspective of generalization, it makes more sense to talk about true parameters than simply the KL-divergence-minimizing parameters. It's the true parameters that give us perfect generalization (in the limit of infinite data).
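To make these objects concrete, here's a minimal numerical sketch (assuming a toy 1-d linear-Gaussian model with unit noise; the parameter values and sample size are made up for illustration). It checks that K_n(w) = L_n(w) - S_n and that the empirical KL bottoms out at the true parameter:

```python
import numpy as np

rng = np.random.default_rng(0)

# True model: q(y|x) = N(y | 2x, 1), with x ~ N(0, 1).
n = 1_000
x = rng.normal(size=n)
y = 2.0 * x + rng.normal(size=n)

def avg_nll(w):
    """Average negative log likelihood L_n(w) for the model p(y|x, w) = N(y | w*x, 1)."""
    return np.mean(0.5 * np.log(2 * np.pi) + 0.5 * (y - w * x) ** 2)

S_n = avg_nll(2.0)  # empirical entropy of the true conditional (w = 2 is the true parameter)

for w in [0.0, 1.0, 2.0, 3.0]:
    K_n = avg_nll(w) - S_n  # empirical KL divergence K_n(w) = L_n(w) - S_n
    print(f"w = {w:.1f}   L_n = {avg_nll(w):.3f}   K_n = {K_n:.3f}")
# K_n(w) is minimized (at exactly zero) at the true parameter w = 2.
```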

The Bayesian Information Criterion is a lie.

One of the main strengths of the Bayesian frame is that it lets us enforce a prior \varphi(w) over the weights, which we can then integrate out to obtain a parameter-free model:

p(yx)=Wp(yx,w)φ(w) dw. p(y|x) = \int_W p(y|x, w)\varphi(w)\ \text{d}w.

One of the main weaknesses is that this integral is almost always intractable. So Bayesians make a concession to the frequentists with a much more tractable Laplace approximation (i.e., you approximate your model as quadratic/Gaussian in the vicinity of the maximum likelihood estimator (MLE), w^{(0)}):2

K(w)12(ww(0))TI(w(0))(ww(0)), K(w) \approx \frac{1}{2}(w-w^{(0)})^T I(w^{(0)}) (w-w^{(0)}),

where I(w)I(w) is the Fisher information matrix:

 I_{j,k}(w)=\iint \left(\frac{\partial}{\partial w_j} \log p(y|x, w)\right)\left(\frac{\partial}{\partial w_k} \log p(y |x, w)\right) p(y, x|w)\, \text{d}x\, \text{d}y.
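As a concrete (toy) instance, not tied to any network above: for the 1-d linear-Gaussian model p(y|x, w) = \mathcal N(y|wx, 1), the score is \partial_w \log p(y|x, w) = (y - wx)\,x, so

 I(w) = \mathbb E_{p(y,x|w)}\left[(y - wx)^2 x^2\right] = \mathbb E_{q(x)}\left[x^2\right],

a single positive number; this toy model is regular.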

placeholder

The Laplace approximation is a probability theorist's Taylor approximation.

From this approximation, a bit more math gives us the Bayesian information criterion (BIC):

 \text{BIC} = n L_n(w^{(0)}) + \frac{D}{2}\log n.

The BIC (like the related Akaike information criterion) is a criterion for model selection that penalizes complexity. Given two models, the one with the lower BIC tends to overfit less (/"generalize better").
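As a quick illustration of BIC-based model selection, here's a hedged sketch (polynomial regression with known unit noise; the data-generating process and candidate degrees are invented for the example), using the form \text{BIC} = n L_n(w^{(0)}) + \frac{D}{2}\log n from above:

```python
import numpy as np

rng = np.random.default_rng(0)

# Data from a quadratic with unit-variance Gaussian noise.
n = 200
x = rng.uniform(-2, 2, size=n)
y = 1.0 + 0.5 * x - 1.5 * x**2 + rng.normal(size=n)

def bic(degree):
    """BIC = n * L_n(w_mle) + (D / 2) * log(n) for a polynomial model of the given degree."""
    D = degree + 1                                  # parameter count
    w_mle = np.polyfit(x, y, degree)                # least squares = MLE under unit Gaussian noise
    resid = y - np.polyval(w_mle, x)
    L_n = np.mean(0.5 * np.log(2 * np.pi) + 0.5 * resid**2)   # average negative log likelihood
    return n * L_n + 0.5 * D * np.log(n)

for degree in [1, 2, 8]:
    print(f"degree {degree}: BIC = {bic(degree):.1f}")
# The quadratic (the smallest model containing the truth) should come out with the lowest BIC.
```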

The problem with regular learning theory is that deriving the BIC invokes the inverse, I1(w(0))I^{-1}(w^{(0)}), of the information matrix. If I(w(0))I(w^{(0)}) is non-invertible, then the BIC and all the generalization results that depend on it are invalid.

As it turns out, information matrices are pretty much never invertible for deep neural networks. So, we have to rethink our theory.
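Here's a minimal numerical sketch of that claim (a tiny one-hidden-layer ReLU regression network with unit output noise; the shapes, seed, and sample count are arbitrary). The empirical Fisher information, assembled from per-example gradients of f_w, comes out singular:

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)

# Tiny ReLU regression net f_w(x) with a scalar output.
# w packs (W1: 4x3, b1: 4, W2: 1x4), so D = 20 parameters.
w0 = rng.normal(size=20)

def f(w, x):
    W1, b1, W2 = w[:12].reshape(4, 3), w[12:16], w[16:].reshape(1, 4)
    return (W2 @ relu(W1 @ x + b1))[0]

def grad_f(w, x, eps=1e-5):
    """Central finite-difference gradient of f with respect to the weights."""
    g = np.zeros_like(w)
    for j in range(len(w)):
        dw = np.zeros_like(w)
        dw[j] = eps
        g[j] = (f(w + dw, x) - f(w - dw, x)) / (2 * eps)
    return g

# For p(y|x, w) = N(y | f_w(x), 1), the Fisher information is E_x[grad f grad f^T].
xs = rng.normal(size=(50, 3))
G = np.stack([grad_f(w0, x) for x in xs])
fisher = G.T @ G / len(xs)

print(np.linalg.eigvalsh(fisher)[:5])  # the smallest eigenvalues are numerically zero
```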

Singularities in the context of Algebraic Geometry

For an analytic function K : W \to \mathbb R, a point x \in W is a critical point of K if the gradient vanishes there, \nabla K(x) = 0. A singularity is a critical point at which K itself is also zero, K(x) = 0.

Under these definitions, any true parameter w^* is a singularity of the KL divergence: K(w^*) = 0 follows from the definition of w^*, and \nabla K(w^*) = 0 follows from the lower bound K(w) \geq 0 (an interior minimum is a critical point).

So another advantage of the KL divergence over the NLL is its clean lower bound, which guarantees that every true parameter w^* is a singularity of K.

We are interested in degenerate singularities: singularities that are not isolated points but occupy a common manifold. At a degenerate singularity, there is some continuous change to w^* which leaves K unchanged. That is, the surface is not locally parabolic.

placeholder

Non-degenerate singularities are locally parabolic. Degenerate singularities are not.

In terms of KK, this means that the Hessian at the singularity has at least one zero eigenvalue (equivalently, it is non-invertible). For the KL-divergence, the Hessian at a true parameter is precisely the Fisher information matrix we just saw.

Generic symmetries of NNs

Neural networks are full of symmetries that let you change the model's internals without changing the overall computation. This is where our degenerate singularities come from.

The most obvious symmetry is that you can permute hidden units (together with their weights) without changing the overall computation. Given any two compatible linear transformations, A and B (i.e., weight matrices), an element-wise activation function, \phi, and any permutation matrix, P,

BϕA=(BP1)ϕ(PA) B \circ \phi \circ A = (B \circ P^{-1}) \circ \phi \circ (P \circ A)

because permutations commute with ϕ\phi. The non-linearity of ϕ\phi means this isn't the case for invertible transformations in general.

(abcdefghi)(jklmnopqr)=(bacedfhgi)(mnojklpqr) \begin{pmatrix} \color{red} a & \color{blue} b & c\\ \color{red} d & \color{blue}e & f \\ \color{red}g & \color{blue} h & i \end{pmatrix} \cdot \begin{pmatrix} \color{red}j & \color{red}k & \color{red}l\\ \color{blue}m & \color{blue}n & \color{blue}o \\ p & q & r \end{pmatrix} =\begin{pmatrix} \color{blue} b & \color{red} a & c\\ \color{blue} e & \color{red}d & f \\ \color{blue}h & \color{red} g & i \end{pmatrix} \cdot \begin{pmatrix} \color{blue}m & \color{blue}n & \color{blue}o \\ \color{red}j & \color{red}k & \color{red}l\\ p & q & r \end{pmatrix}

An example using the identity function for ϕ\phi.
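A quick numerical check of this permutation symmetry, now with ReLU as the element-wise \phi (a minimal NumPy sketch; the shapes and the particular permutation are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)

A = rng.normal(size=(4, 3))   # incoming weights (3 inputs -> 4 hidden units)
B = rng.normal(size=(2, 4))   # outgoing weights (4 hidden units -> 2 outputs)
x = rng.normal(size=3)

P = np.eye(4)[[2, 0, 3, 1]]   # a permutation matrix; P^{-1} = P^T

# B . phi . A  ==  (B . P^{-1}) . phi . (P . A), because P commutes with the element-wise phi.
assert np.allclose(B @ relu(A @ x), (B @ P.T) @ relu(P @ A @ x))
```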

At least for this post, we'll ignore this symmetry as it is discrete, and we're interested in continuous symmetries that can give us degenerate singularities.

A more promising continuous symmetry is the following (for models that use ReLUs):

ReLU(x)=1αReLU(αx),α>0. \text{ReLU}(x) =\frac{1}{\alpha}\text{ReLU}(\alpha x),\quad \alpha > 0.

For a ReLU layer, you can scale the incoming weights and biases by any \alpha > 0, as long as you scale the outgoing weights by 1/\alpha. Since this symmetry is present over the entire parameter space, W, nowhere in weight space is safe from degeneracy.
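And a numerical check of the rescaling symmetry (same kind of toy setup; \alpha can be any positive number):

```python
import numpy as np

rng = np.random.default_rng(1)
relu = lambda z: np.maximum(z, 0.0)

W_in, b_in = rng.normal(size=(4, 3)), rng.normal(size=4)   # incoming weights/bias of the ReLU layer
W_out = rng.normal(size=(2, 4))                            # outgoing weights
x = rng.normal(size=3)
alpha = 3.7                                                # any alpha > 0 works

# Scaling the incoming weights by alpha and the outgoing weights by 1/alpha leaves the
# computation unchanged, because ReLU(alpha * z) = alpha * ReLU(z) for alpha > 0.
assert np.allclose(
    W_out @ relu(W_in @ x + b_in),
    (W_out / alpha) @ relu(alpha * (W_in @ x + b_in)),
)
```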

As an exercise for the reader, can you think of any other generic symmetries in deep neural networks?3

Both of these symmetries are generic. They're an after-effect of the architecture choice, and are always active. The more interesting symmetries are non-generic symmetries — those that depend on ww.

Non-Generic Symmetries

The key observation of singular learning theory for neural networks is that neural networks can vary their effective parameter count.

Real Log Canonical Threshold (RLCT)

Zeta function of K(w)K(w)

ζ(z)=WK(w)zϕ(w)dw. \zeta(z)=\int_W K(w)^z \phi(w)\,\text{d}w.

where wW:ϕ(w)>0\forall w \in W: \phi(w)>0.

Analytically continue this to a meromorphic function on the whole complex plane; its poles are real and negative. The largest pole, z = -\lambda, gives the real log canonical threshold \lambda (and its order gives the multiplicity m).
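To make this a bit more concrete, here's a toy computation (no neural network involved; assume a flat prior on [-1, 1]^D). For a regular one-parameter model with K(w) = w^2,

 \zeta(z) = \int_{-1}^{1} (w^2)^z \cdot \tfrac{1}{2}\,\text{d}w = \frac{1}{2z+1},

which has a single pole at z = -1/2, so \lambda = 1/2 = D/2. For the degenerate two-parameter model K(w_1, w_2) = w_1^2 w_2^2,

 \zeta(z) = \tfrac{1}{4}\int_{-1}^{1}\int_{-1}^{1} (w_1^2 w_2^2)^z\,\text{d}w_1\,\text{d}w_2 = \frac{1}{(2z+1)^2},

so the largest pole is again z = -1/2 (now with multiplicity 2), and \lambda = 1/2 < D/2 = 1: the degeneracy lowers the effective parameter count.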

Missing good image of what is going on here. Yes, the pole is the location of a singularity (ζ(z)\zeta(z) \to \infty) and its multiplicity is the order of the polynomial you need to approximate the local behavior of the corresponding zero of ζ1\zeta^{-1}. But how do I actually interpret zz?

Real log canonical threshold (λ\lambda)

  • The RLCT is the volume co-dimension (a measure of the effective number of parameters near the most singular point in W0W_0).
  • For regular (non-singular) models, the RLCT is precisely D/2D/2
  • Why divide by two?

Orientation-reversing symmetries

Using ReLUs, we can imagine our network performing a kind of piecewise, high-dimensional spline approximation of the target function. In higher dimensions, the boundaries between pieces are hypersurfaces of constant pre-activation.

Each such boundary is described by a linear equation in the parameters (a weighted sum set to zero), which fixes an orientation. Reversing the orientation describes the same boundary.

Degenerate nodes

In addition to these symmetries, when the model has more hidden nodes than the true model requires, the excess nodes are either degenerate or share an activation boundary with another node.

NOTE: I'm still confused here.

INSERT comment on equivariance (link Distill article)


Glossary

  • If true parameters exist at all (i.e., the set W_0 is non-empty: there is some choice of weights w_0 for which K(w_0) = 0), we say our model is realizable. We'll assume this is the case from now on.
  • When every choice of weights corresponds to a unique model (i.e., the map w \mapsto p(y|x, w) is injective), we say our model is identifiable.
  • If a model is identifiable and its Fisher information matrix is positive definite for every w, the model is regular. Otherwise, the model is strictly singular.

Singular learning theory kicks into gear when our models are strictly singular: when the true parameters are degenerate singularities of K, the Laplace approximation, and the BIC along with it, breaks down.

Restated in terms of eigenvalues: if the Fisher information matrix is strictly positive definite (it has no zero eigenvalues) for any w, then an identifiable model is called regular. A non-regular model is called strictly singular.

Example: Pendulum

Our objects of study are triples of the kind (p(y|x,w), q(y|x), \phi(w)), where p(y|x, w) is a model of some unknown true distribution, q(y|x), and \phi(w) is a prior over the weights.

The model itself is a regression model on ff:

 p(y|x, w) = \mathcal{N}(y|f(x, w), \mathbb 1_M)

We have some true distribution q(y|x) and a model p(y|x, w) parameterized by weights w. Our aim is to learn the weights so as to capture the true distribution, and we assume q(y|x) is realizable.

Let's start with an example to understand how learning changes when your models are singular.

You have a pendulum with some initial angular displacement x0x_0 and velocity v0v_0. Newton tells us that at time tt, it'll be in position

xt=f(x0,v0,g,t)=x0cos(gt)+v0gsin(gt). x_t = f(x_0, v_0, g, t) = x_0 \cos(\sqrt{g}\cdot t) + \frac{v_0}{\sqrt{g}}\sin(\sqrt{g}\cdot t).

The problem is that we live in the real world. Our measurement of xtx_t is noisy:

p(xtx0,v0,g,t)=N(xtf(x0,v0,g,t), ϵ2) p(x_t|x_0, v_0, g, t) = \mathcal N(x_t|f(x_0, v_0, g, t),\ \epsilon^2)

What we'd like to do is to learn the "true" parameters (x0,v0,g,t)(x_0^*, v_0^*, g^*, t^*) from our observations xtx_t. That gives us a problem: For any set of true parameters, the following would output the same value:

 \{(x_0^*, a v_0^*, a^2 g^*, \tfrac{1}{a}t^*)\ |\ a > 0\}

That is: our map from parameters to models is non-injective. Multiple parameters determine the same model. We call these models strictly singular.
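A quick numerical check of this degeneracy (a minimal sketch; the parameter values below are made up for illustration):

```python
import numpy as np

def pendulum_position(x0, v0, g, t):
    """The small-angle pendulum solution from above (unit length)."""
    w = np.sqrt(g)
    return x0 * np.cos(w * t) + (v0 / w) * np.sin(w * t)

x0, v0, g, t = 0.3, 1.2, 9.8, 0.7   # hypothetical "true" parameters
reference = pendulum_position(x0, v0, g, t)

# Every point on the orbit (x0, a*v0, a^2*g, t/a) predicts exactly the same position.
for a in [0.5, 2.0, 10.0]:
    assert np.isclose(reference, pendulum_position(x0, a * v0, a**2 * g, t / a))
```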

  • Aim: Modeling a pendulum given a noisy estimate of (x,y)(x, y) at time tt.

  • The parameters of our model are (\lambda, g, t) (initial velocity, gravitational acceleration, time of measurement), and the model is: f(x, y; \lambda, g, t) = x \cos(\sqrt{g}\, t) + \frac{\lambda}{\sqrt{g}}\sin(\sqrt{g}\, t).

  • The map from parameters to models is non-injective. That is, the function, ff, is exactly the same after a suitable mapping (like g4gg \to 4g, λ2λ\lambda \to 2 \lambda, tt/2t \to t/2).
  • You can reparameterize this model to get rid of the degeneracy (λ=λ/g\lambda' = \lambda/\sqrt{g}, t=gtt' = \sqrt g \cdot t):
f(x,y;λ,t)=xcos(t)+λsin(t).f(x, y; \lambda', t') = x \cos(t') + \lambda'\sin(t').
  • But that may actually make the parameters less useful to reason about, and, in general, may make the "true" model harder to find.

(If you look at the tλt - \lambda plane, you get straight line level sets, same for tgt - \sqrt g plane).

Relation to Hessian of K(w)K(w) (KL-divergence between a parameter and the true model)

…TODO

Connection to basin broadness

…TODO

Emphasize this early on.

Questions

  • Does SGD actually reach these solutions? For a given loss, if we are uniformly distributed across all weights with that loss, we should end up in simpler solutions, right? Does this actually happen though?
  • Is part of the value of depth that it creates more ReLU-like symmetries? Can you create equally successful shallow, wide models if you hardcode additional symmetries?

Footnotes

  1. E.g.: that explicit regularization enforces simpler solutions (weight decay is a Gaussian prior over weights), that SGD settles in broader basins that are more robust to changes in parameters (=new samples), that NNs have Solomonoff-like inductive biases [1], or that highly correlated weight matrices act as implicit regularizers [2].

  2. This is just a second-order Taylor approximation modified for probability distributions. That is, the Fisher information matrix gives you the curvature of the negative log likelihood: it tells you how many bits you gain (=how much less likely your dataset becomes) as you move away from the minimum in parameter space.

  3. Hint: normalization layers, the embedding/unembedding layers of transformers, or anywhere else without a privileged basis.