# The edge of stability (Wed, Mar 08)

## Self-stabilization dynamics

• Define the set $\mathcal{M} \seteq \{ \theta \in \R^n : S(\theta) \leq 2/\eta \textrm{ and } \langle \nabla L(\theta), u(\theta)\rangle = 0 \}$, where we recall that $u(\theta)$ is the top eigenvector of $\nabla^2 L(\theta)$.

Given the gradient descent dynamics $\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)$, we define $\hat{\theta}_t \seteq \mathrm{proj}_{\mathcal{M}}(\theta_t)$ as the (Euclidean) projection of $\theta_t$ onto the set $\mathcal{M}$ (which need not be convex in general, so we assume the projection is well defined).

• Now define the quantities

\begin{align*} x_t &\seteq \langle u, \theta_t - \hat{\theta}_0\rangle \\ y_t &\seteq \langle \nabla S(\hat{\theta}_0), \theta_t - \hat{\theta}_0\rangle, \end{align*}

where $u \seteq u(\hat{\theta}_0)$ is the top eigenvector of $\nabla^2 L(\hat{\theta}_0)$ and $S(\theta) \seteq \lambda_{\max}(\nabla^2 L(\theta))$ is the sharpness.

If we assume that $S(\hat{\theta}_0)=2/\eta$ and $\nabla S(\theta_t) \approx \nabla S(\hat{\theta}_0)$, then the first-order expansion $S(\theta_t) \approx S(\hat{\theta}_0) + \langle \nabla S(\hat{\theta}_0), \theta_t - \hat{\theta}_0\rangle$ gives

$$\label{eq:sharp} S(\theta_t) \approx 2/\eta + y_t\,.$$

• Linear approximation. For $x_t,y_t$ small, we can approximate $\nabla L(\theta_t) \approx \nabla L(\hat{\theta}_0)$.

• Quadratic approximation. When $\abs{x_t}$ blows up, we use a quadratic approximation:

\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) \\ &\approx \langle u, \theta_t - \hat{\theta}_0\rangle \nabla^2 L(\theta_t) u\,. \end{align*}

In the last line, we have used the fact that $\theta_t-\hat{\theta}_0$ only blows up in the $u$ direction, so we only use the quadratic approximation there.

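Since $\langle u, \theta_t - \hat{\theta}_0\rangle = x_t$, this quantity equals $x_t \nabla^2 L(\theta_t) u \approx x_t S(\theta_t) u$, where the approximation assumes that $u$ remains close to the top eigenvector of $\nabla^2 L(\theta_t)$.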

In conjunction with $\langle u, \nabla L(\hat{\theta}_0)\rangle = 0$, this gives

$$\label{eq:xpart} x_{t+1} = x_t - \eta \langle u, \nabla L(\theta_t)\rangle \approx x_t(1-\eta S(\theta_t)) \approx -(1+\eta y_t) x_t\,,$$

where the last approximation uses \eqref{eq:sharp}. Here we see the blowup phase where $\abs{x_t}$ grows multiplicatively as soon as $y_t > 0$, i.e., as soon as the sharpness exceeds $2/\eta$, and the sign of $x_t$ oscillates.
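To make the blowup phase concrete, here is a minimal sketch that iterates the recursion $x_{t+1} = -(1+\eta y_t) x_t$ with $y_t$ frozen at a positive constant. The function name `blowup_trajectory` and the numerical values of `x0`, `y`, and `eta` are illustrative assumptions, not quantities from the lecture.

```python
def blowup_trajectory(x0, y, eta, steps):
    """Iterate x_{t+1} = -(1 + eta * y) * x_t with y held fixed.

    Freezing y is a simplifying assumption made only for this sketch;
    in the full dynamics y_t evolves together with x_t.
    """
    xs = [x0]
    for _ in range(steps):
        xs.append(-(1.0 + eta * y) * xs[-1])
    return xs


# Hypothetical values: a small initial displacement and sharpness slightly above 2/eta.
xs = blowup_trajectory(x0=1e-3, y=0.5, eta=0.1, steps=6)
for t, x in enumerate(xs):
    print(t, x)
```

Each step multiplies $\abs{x_t}$ by $1 + \eta y$ and flips its sign, which is the oscillating growth described above.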

• Cubic terms. For $\abs{x_t}$ sufficiently large, we need to incorporate the 3rd-order term in the Taylor expansion:

\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\theta_t)[\theta_t-\hat{\theta}_0,\theta_t-\hat{\theta}_0] \\ &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\hat{\theta}_0)[\theta_t-\hat{\theta}_0,\theta_t-\hat{\theta}_0] \\ &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\hat{\theta}_0)[u,u] \langle u, \theta_t - \hat{\theta}_0\rangle^2\,, \end{align*}

where as in the quadratic setting, we only consider the component in direction $u$. Using the approximation we derived for the quadratic part, this gives

\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx x_t S(\theta_t) u + \frac{x_t^2}{2} \nabla^3 L(\hat{\theta}_0)[u,u] \\ &=x_t S(\theta_t) u + \frac{x_t^2}{2} \nabla S(\hat{\theta}_0)\,, \end{align*}

where we used the fact that if the top eigenvalue of $\nabla^2 L(\theta)$ has multiplicity one, then $\nabla S(\theta) = \nabla^3 L(\theta)[u(\theta),u(\theta)]$.
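The identity $\nabla S(\theta) = \nabla^3 L(\theta)[u(\theta),u(\theta)]$ can be sanity-checked numerically. The sketch below uses a hypothetical toy loss $L(a,b) = \tfrac12 (1+b^2) a^2 + \tfrac14 b^4$ (chosen only because its Hessian has a closed form) and central finite differences. It compares the gradient of $\lambda_{\max}(\nabla^2 L(\theta))$ with the gradient of $\theta \mapsto u^\top \nabla^2 L(\theta) u$ for $u$ frozen at the evaluation point, which is exactly $\nabla^3 L[u,u]$.

```python
import numpy as np


def hessian(theta):
    """Hessian of the toy loss L(a, b) = (1 + b^2) a^2 / 2 + b^4 / 4 (an assumed example)."""
    a, b = theta
    return np.array([[1.0 + b**2, 2.0 * a * b],
                     [2.0 * a * b, a**2 + 3.0 * b**2]])


def sharpness(theta):
    """S(theta): largest eigenvalue of the Hessian (eigvalsh sorts ascending)."""
    return np.linalg.eigvalsh(hessian(theta))[-1]


def top_eigenvector(theta):
    """Unit eigenvector for the largest eigenvalue of the Hessian."""
    _, vecs = np.linalg.eigh(hessian(theta))
    return vecs[:, -1]


def central_diff(f, theta, h=1e-5):
    """Central finite-difference gradient of a scalar function f at theta."""
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e[i] = h
        grad[i] = (f(theta + e) - f(theta - e)) / (2.0 * h)
    return grad


theta = np.array([0.3, 0.5])      # arbitrary point where the top eigenvalue is simple
u = top_eigenvector(theta)        # frozen at theta for the third-derivative contraction

grad_S = central_diff(sharpness, theta)
third_uu = central_diff(lambda th: u @ hessian(th) @ u, theta)

print("grad S        :", grad_S)
print("grad^3 L[u, u]:", third_uu)
```

The two vectors should agree up to finite-difference error. The agreement relies on the top eigenvalue being simple at `theta`, matching the multiplicity-one assumption above.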

• Then our update to $y_t$ looks like

\begin{align*} y_{t+1} - y_t &= \langle \nabla S(\hat{\theta}_0), - \eta \nabla L(\theta_t)\rangle \\ &= -\eta \langle \nabla S(\hat{\theta}_0), \nabla L(\hat{\theta}_0)\rangle - \eta \langle \nabla S(\hat{\theta}_0), \nabla L(\theta_t)-\nabla L(\hat{\theta}_0)\rangle \\ &\approx \eta \alpha - \eta x_t S(\theta_t) \langle \nabla S(\hat{\theta}_0), u \rangle - \eta \frac{x_t^2}{2} \beta\,, \end{align*}

where $\alpha \seteq \langle \nabla S(\hat{\theta}_0), - \nabla L(\hat{\theta}_0)\rangle$ is the progressive sharpening parameter and $\beta \seteq \|\nabla S(\hat{\theta}_0)\|^2$ is the “stabilization” parameter.

• If we assume for simplicity that $\langle \nabla S(\hat{\theta}_0), u\rangle = 0$, then we obtain

$$\label{eq:ypart} y_{t+1} - y_t \approx \eta \left(\alpha - \beta \frac{x_t^2}{2}\right)\,,$$

meaning that the sharpness decreases whenever $\abs{x_t} > \sqrt{2\alpha/\beta}$.
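The two recursions \eqref{eq:xpart} and \eqref{eq:ypart} can be iterated together to watch self-stabilization happen. The following is a rough sketch under stated assumptions: the values of `eta`, `alpha`, `beta`, and the initial condition are hypothetical, and the approximate relations are treated as exact updates.

```python
import math


def simulate(x0, y0, eta, alpha, beta, steps):
    """Iterate the coupled system
        x_{t+1} = -(1 + eta * y_t) * x_t
        y_{t+1} = y_t + eta * (alpha - beta * x_t^2 / 2)
    treating the approximations in the text as exact updates."""
    xs, ys = [x0], [y0]
    for _ in range(steps):
        x, y = xs[-1], ys[-1]
        xs.append(-(1.0 + eta * y) * x)
        ys.append(y + eta * (alpha - beta * x * x / 2.0))
    return xs, ys


# Hypothetical parameters, chosen only for illustration.
eta, alpha, beta = 0.01, 1.0, 1.0
threshold = math.sqrt(2.0 * alpha / beta)

xs, ys = simulate(x0=0.01, y0=0.0, eta=eta, alpha=alpha, beta=beta, steps=800)

# First step at which y (the sharpness minus 2/eta) decreases.
first_decrease = next(t for t in range(len(xs) - 1) if ys[t + 1] < ys[t])
print("threshold sqrt(2 alpha / beta):", threshold)
print("first step with decreasing y  :", first_decrease)
print("|x| at that step               :", abs(xs[first_decrease]))
```

While $\abs{x_t}$ is small, $y_t$ drifts upward at rate roughly $\eta\alpha$ (progressive sharpening). Once $y_t > 0$, $\abs{x_t}$ grows geometrically until it crosses $\sqrt{2\alpha/\beta}$, at which point the $-\beta x_t^2/2$ term dominates and $y_t$ starts to fall.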