At first I thought this would be just another gradient descent tutorial for beginners. But the article goes quite deep into gradient descent dynamics, looking into third-order approximations of the loss function and eventually motivating a concept called "central flows." Their central flow model was able to predict loss curves for various training runs across different neural network architectures.
Fascinating. Do the insights gained allow one to compute the central flow directly in order to speed up convergence? Or is this a preliminary exploration to understand how training has been working all along?
They explicitly ignore momentum and exponentially weighted moving averages, but those should approximate the time-averaged gradient descent trajectory (along the valley, not across it). That requires multiple evaluations, though. Do any of the expressions for the central flow admit a fast, computationally efficient calculation?
> We emphasize that the central flow is a theoretical tool for understanding optimizer behavior, not a practical optimization method. In practice, maintaining an exponential moving average of the iterates (e.g., Morales-Brotons et al., 2024) is likely a computational feasible way to estimate the optimizer’s time-averaged trajectory.
They analyze the behavior of RMSProp (Adam without momentum) using their framework and come up with simplified mathematical models that predict actual training behavior in experiments. Their models seem to explain why RMSProp works in a way that is more satisfying than the usual hand-waving explanations.
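For reference, here is a minimal numpy sketch of the standard RMSProp update. This is not the paper's code; the function name, hyperparameter defaults, and toy quadratic are my own illustrative choices. It keeps an exponential moving average of squared gradients and divides each coordinate of the step by its square root, which is essentially Adam with the momentum coefficient set to zero and no bias correction.

```python
import numpy as np


def rmsprop_step(w, grad, v, lr=1e-2, beta=0.99, eps=1e-8):
    """One RMSProp update (essentially Adam with beta1 = 0 and no bias correction)."""
    v = beta * v + (1.0 - beta) * grad**2   # EMA of squared gradients, per coordinate
    w = w - lr * grad / (np.sqrt(v) + eps)  # step rescaled by the RMS gradient
    return w, v


# Hypothetical toy problem: L(w) = 0.5 * sum(h * w**2), one sharp and one flat direction.
h = np.array([100.0, 1.0])
w = np.array([1.0, 1.0])
v = np.zeros_like(w)

losses = [0.5 * np.sum(h * w**2)]
for _ in range(500):
    grad = h * w                            # exact gradient of the quadratic
    w, v = rmsprop_step(w, grad, v)
    losses.append(0.5 * np.sum(h * w**2))

print(f"initial loss {losses[0]:.3f}, final loss {losses[-1]:.2e}")
```

Because each coordinate is divided by its own RMS gradient, the sharp (h = 100) and flat (h = 1) directions move at nearly the same rate here. Any stability analysis of RMSProp has to account for this kind of per-coordinate rescaling.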
Yes, it certainly provides a lot more clarity than the handwaving.
Momentum seems to work, and the authors clearly state that the central flow is not intended as a practical optimization method, but I can't rule out that we could improve convergence rates by building on this knowledge.
Is the oscillating behavior guaranteed to have a period of 2 steps? Or is, say, a 3-step period also possible (a vector in a plane could alternately point at 0, 120, and 240 degrees)?
The way I read this presentation, the implication seems to be that it's always a period of 2. Perhaps if the top two sharpnesses are degenerate (identical), a period other than 2 could be possible?
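Here is a hedged partial answer that covers only the local quadratic model L(w) ≈ ½ wᵀHw, not the higher-order effects the article models. One plain gradient step multiplies the component along each Hessian eigenvector by the real factor 1 - ηλ. Because I - ηH is symmetric, those factors are real. A factor of -1 flips the sign every step, which gives period 2, whereas a 120-degree rotation would need a complex factor. Even with a degenerate top eigenvalue, the whole eigenspace gets the same real factor, so nothing rotates within that plane. The step size and Hessian below are made-up values.

```python
import numpy as np

# Local quadratic model L(w) = 0.5 * w @ H @ w, with a degenerate top eigenvalue.
eta = 0.1
H = np.diag([20.0, 20.0, 1.0])        # top two eigenvalues equal 2 / eta = 20
w0 = np.array([1.0, 0.5, 0.0])         # starts inside the degenerate top eigenspace

# GD multiplies each eigen-component by the real factor (1 - eta * lambda).
multipliers = np.linalg.eigvalsh(np.eye(3) - eta * H)

w = w0.copy()
iterates = [w.copy()]
for _ in range(6):
    w = w - eta * (H @ w)               # one plain gradient descent step
    iterates.append(w.copy())

print("multipliers:", multipliers)
for k, it in enumerate(iterates):
    print(k, it)
```

The iterates alternate between w0 and -w0, so the period is 2 and never 3, even though the top eigenvalue is repeated.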
It makes you wonder what would happen if, instead of storing momentum as an exponential moving average, one used the average of the last 2 iterates, so there would be less lag.
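A quick way to see the lag trade-off is a hypothetical synthetic trajectory: a slow linear drift along the valley plus a clean period-2 oscillation across it. Real iterates are noisier than this. The average of the last two iterates cancels the oscillation exactly and lags by half a step. An EMA with beta = 0.9 lags by roughly beta / (1 - beta) = 9 steps and leaves a residual wiggle.

```python
import numpy as np

T = 200
t = np.arange(T)
center = 0.01 * t                      # hypothetical slow drift along the valley
x = center + 0.5 * (-1.0) ** t         # plus a period-2 oscillation across it

# Estimator 1: average of the last two iterates.
# two_step[k] averages x[k] and x[k + 1], so it lags the newest center by half a step.
two_step = 0.5 * (x[1:] + x[:-1])

# Estimator 2: exponential moving average of the iterates.
beta = 0.9
ema = np.empty(T)
ema[0] = x[0]
for k in range(1, T):
    ema[k] = beta * ema[k - 1] + (1 - beta) * x[k]

err_two_step = abs(two_step[-1] - center[-1])
err_ema = abs(ema[-1] - center[-1])
print(f"two-iterate average error: {err_two_step:.4f}, EMA error: {err_ema:.4f}")
```

The exact cancellation holds only for a clean period-2 signal. With noise or a different period, the two-iterate average would lose that advantage.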
It also makes me wonder whether we should perform 2 iterative steps PER sequence, so that the single-sample sequence gives feedback along its valley instead of across it. One would go through the corpus at half the speed, but convergence may be more accurate.
So all the classic optimization theory about staying in the stable region is basically what deep learning doesn't do. The model literally learns by becoming unstable, oscillating, and then using that energy to self-correct.
The chaos is the point. What a crazy, beautiful mess.
Researchers are constantly looking to train more expressive models more quickly. Any method that can converge while taking large jumps will be chosen. You are more or less guaranteed to end up in a place where the sharpness is high but somehow comes under control. If we weren't there, we'd try new architectures until we arrived there. So deep learning doesn't "do this"; we do it by any method possible, and the architectures that got us there happen to be the ones that currently fall under "deep learning". Keep in mind that many deep architectures do not converge; you're seeing survivorship bias.
Reminds me of simulated annealing. Some randomness has always been part of optimization processes that seek a better equilibrium than a local one. Genetic algorithms have mutation, simulated annealing has temperature, and stochastic gradient descent similarly has random batches.
Very neat stuff! So one question is, if we had an analog computer that could run these flows exactly, would we get better results if we ran the gradient flow or this central flow?
It's a little easier to see what's happening if you fully write out the central flow:
-1/η * dw/dt = ∇L - ∇S * ⟨∇L, ∇S⟩/‖∇S‖²
We're projecting the loss gradient onto the sharpness gradient, and subtracting it off. If you didn't read the article, the sharpness S is the sum of the eigenvalues of the Hessian of the loss that are larger than 2/η, a measure of how unstable the learning dynamics are.
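Here is the projection above as a small numpy function. grad_L and grad_S are assumed to be given vectors. In practice, ∇S involves third derivatives of the loss, which this sketch does not compute. It also treats S as a single scalar constraint, and the random vectors are stand-ins.

```python
import numpy as np


def central_flow_direction(grad_L, grad_S):
    """Remove the component of grad_L along grad_S.

    Returns grad_L - grad_S * <grad_L, grad_S> / ||grad_S||^2, i.e. -1/eta * dw/dt
    for a single sharpness constraint. Both gradients are assumed to be given.
    """
    coeff = np.dot(grad_L, grad_S) / np.dot(grad_S, grad_S)
    return grad_L - coeff * grad_S


# Hypothetical stand-in gradients, just to check the geometry.
rng = np.random.default_rng(0)
grad_L = rng.normal(size=5)
grad_S = rng.normal(size=5)
d = central_flow_direction(grad_L, grad_S)

# dS/dt = <grad_S, dw/dt> = -eta * <grad_S, d>, which should vanish.
print("<grad_S, d> =", np.dot(grad_S, d))
```

Since the result is orthogonal to ∇S, the flow keeps S constant to first order, and projecting a second time changes nothing.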
This is almost Sobolev preconditioning:
-1/η * dw/dt = ∇L - ∇S = ∇(I - Δ)L
where this time S is the sum of all the eigenvalues (so, the Laplacian of L).
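The identity behind that line uses the fact that the Laplacian of L is the trace of its Hessian:

```latex
\Delta L = \operatorname{tr}\,\nabla^2 L = \sum_i \lambda_i\!\left(\nabla^2 L\right),
\qquad
\nabla L - \nabla S = \nabla L - \nabla(\Delta L) = \nabla\big((I - \Delta)\,L\big)
```

Compared with the central flow expression, the coefficient on ∇S is fixed at 1 instead of ⟨∇L, ∇S⟩/‖∇S‖², and S sums every eigenvalue rather than only those above 2/η.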
Yeah, I did a lot of traditional optimization problems during my Ph.D., and this type of expression pops up all the time with higher-order gradient-based methods. You rescale or otherwise adjust the gradient based on some characteristic eigenvalues of the system to promote convergence without overshooting too much.
I apparently didn't get the memo and used stochastic gradient descent with momentum outside of deep learning without running into any problems given a sufficiently low learning rate.
I'm not really convinced that their explanation truly captures why this success should be exclusive to deep learning.