Define the set \(\mathcal{M} \seteq \{ \theta \in \R^n : S(\theta) \leq 2/\eta \textrm{ and } \langle \nabla L(\theta), u(\theta)\rangle = 0 \}\), where we recall that \(u(\theta)\) is the top eigenvector of \(\nabla^2 L(\theta)\).
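Given the gradient descent dynamics \(\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)\), we define \(\hat{\theta}_t \seteq \mathrm{proj}_{\mathcal{M}}(\theta_t)\) as the (Euclidean) projection of \(\theta_t\) onto \(\mathcal{M}\). Since \(\mathcal{M}\) is in general not convex, we assume that this projection is well defined, i.e., that \(\theta_t\) has a unique nearest point in \(\mathcal{M}\).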
Now define the quantities
\[\begin{align*} x_t &\seteq \langle u, \theta_t - \hat{\theta}_0\rangle \\ y_t &\seteq \langle \nabla S(\hat{\theta}_0), \theta_t - \hat{\theta}_0\rangle, \end{align*}\]where \(u \seteq u(\hat{\theta}_0)\) is the top eigenvector of \(\nabla^2 L(\hat{\theta}_0)\) and \(S(\theta) \seteq \lambda_{\max}(\nabla^2 L(\theta))\) is the sharpness.
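If we assume that \(S(\hat{\theta}_0)=2/\eta\) and that \(\nabla S(\theta) \approx \nabla S(\hat{\theta}_0)\) between \(\hat{\theta}_0\) and \(\theta_t\), then a first-order Taylor expansion of \(S\) around \(\hat{\theta}_0\) gives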
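\[\begin{equation}\label{eq:sharp} S(\theta_t) \approx 2/\eta + y_t\,. \end{equation}\]
Linear approximation. For \(x_t,y_t\) small, we can approximate \(\nabla L(\theta_t) \approx \nabla L(\hat{\theta}_0)\). Since \(\hat{\theta}_0 \in \mathcal{M}\), we have \(\langle u, \nabla L(\hat{\theta}_0)\rangle = 0\), so this gives \(x_{t+1} \approx x_t\) and \(y_{t+1} - y_t \approx -\eta \langle \nabla S(\hat{\theta}_0), \nabla L(\hat{\theta}_0)\rangle\): the sharpness changes by a constant amount per step, and it increases (progressive sharpening) whenever \(\langle \nabla S(\hat{\theta}_0), \nabla L(\hat{\theta}_0)\rangle < 0\).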
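The first-order approximation \eqref{eq:sharp} can be sanity-checked on a small example. The NumPy sketch below uses a hypothetical toy loss \(L(\theta) = \tfrac12 \theta^\top A \theta + \tfrac16 \sum_i c_i \theta_i^3\) (not from the text), whose Hessian \(A + \mathrm{diag}(c \odot \theta)\) has a well-separated top eigenvalue. It picks \(\eta\) so that \(S(\hat{\theta}_0) = 2/\eta\), estimates \(\nabla S(\hat{\theta}_0)\) by central finite differences, and compares \(S(\theta_t)\) with \(2/\eta + y_t\) for a small displacement. Here \(\hat{\theta}_0\) is just a random point standing in for the projection, not an actual element of \(\mathcal{M}\).

```python
import numpy as np

# Hypothetical toy loss (an assumption for illustration only):
#   L(theta) = 0.5 * theta^T A theta + (1/6) * sum_i c_i * theta_i**3,
# whose Hessian is A + diag(c * theta).
rng = np.random.default_rng(0)
n = 5
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Q @ np.diag([10.0, 6.0, 4.0, 2.0, 1.0]) @ Q.T   # well-separated top eigenvalue
c = rng.standard_normal(n)

def hessian(theta):
    return A + np.diag(c * theta)

def sharpness(theta):
    # S(theta) = largest eigenvalue of the symmetric Hessian (eigvalsh sorts ascending)
    return np.linalg.eigvalsh(hessian(theta))[-1]

def grad_sharpness_fd(theta, h=1e-5):
    # Central finite-difference estimate of nabla S(theta)
    g = np.zeros(n)
    for i in range(n):
        e = np.zeros(n)
        e[i] = h
        g[i] = (sharpness(theta + e) - sharpness(theta - e)) / (2 * h)
    return g

theta_hat0 = 0.1 * rng.standard_normal(n)   # stands in for hat{theta}_0
eta = 2.0 / sharpness(theta_hat0)           # chosen so that S(hat{theta}_0) = 2/eta
delta = 1e-4 * rng.standard_normal(n)       # theta_t - hat{theta}_0 (small)
y = grad_sharpness_fd(theta_hat0) @ delta   # y_t = <nabla S(hat{theta}_0), theta_t - hat{theta}_0>

exact = sharpness(theta_hat0 + delta)
approx = 2.0 / eta + y
print(f"S(theta_t) = {exact:.10f}   2/eta + y_t = {approx:.10f}")
```

Since the approximation is a first-order expansion, the gap between the two printed values should shrink roughly quadratically as `delta` is scaled down.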
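Quadratic approximation. Once \(\abs{x_t}\) starts to blow up, this is no longer accurate, and we instead use a quadratic approximation of \(L\), i.e., a first-order Taylor expansion of \(\nabla L\):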
\[\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) \\ &\approx \langle u, \theta_t - \hat{\theta}_0\rangle \nabla^2 L(\theta_t) u\,. \end{align*}\]In the last line, we have assumed that \(\theta_t-\hat{\theta}_0\) becomes large only in the direction \(u\), so we apply the quadratic approximation only to that component.
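By the definition of \(x_t\), this equals \(x_t \nabla^2 L(\theta_t) u \approx x_t S(\theta_t) u\), where the approximation assumes that \(u\) remains close to the top eigenvector of \(\nabla^2 L(\theta_t)\).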
Combining this with \(\langle u, \nabla L(\hat{\theta}_0)\rangle = 0\), which holds because \(\hat{\theta}_0 \in \mathcal{M}\), gives
\[\begin{equation}\label{eq:xpart} x_{t+1} = x_t - \eta \langle u, \nabla L(\theta_t)\rangle \approx x_t(1-\eta S(\theta_t)) \approx -(1+\eta y_t) x_t\,, \end{equation}\]where the last approximation uses \eqref{eq:sharp}. This is the blowup phase: as soon as \(y_t > 0\), i.e., as soon as the sharpness exceeds \(2/\eta\), \(\abs{x_t}\) grows by a factor of \(1+\eta y_t\) per step, while the sign of \(x_t\) alternates.
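To make the blowup in \eqref{eq:xpart} concrete, the following plain-Python sketch holds \(y_t\) fixed at a positive value and iterates the \(x\)-update. The step size and the value of \(y_t\) are arbitrary illustrative choices, not from the text; the sketch only checks that \(1 - \eta S(\theta_t)\) with \(S(\theta_t) = 2/\eta + y_t\) equals \(-(1+\eta y_t)\), and that the iterates flip sign while their magnitude grows geometrically.

```python
# Hypothetical values: step size eta and a fixed y_t > 0 (sharpness just above 2/eta).
eta = 0.1
y = 0.5
S = 2.0 / eta + y          # S(theta_t) ≈ 2/eta + y_t, as in eq:sharp
factor = 1.0 - eta * S     # multiplier in x_{t+1} ≈ x_t * (1 - eta * S(theta_t))

x = 1e-3                   # small initial component of theta_t - hat{theta}_0 along u
xs = [x]
for _ in range(10):
    x = factor * x         # equivalently x_{t+1} ≈ -(1 + eta * y_t) * x_t
    xs.append(x)

for t, v in enumerate(xs):
    print(t, f"{v:+.6f}")
```

With a small negative `y` instead (sharpness slightly below \(2/\eta\)), the same loop makes \(\abs{x_t}\) shrink while the sign still alternates.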
Cubic terms. For \(\abs{x_t}\) sufficiently large, we also need to incorporate the third-order term of the Taylor expansion:
\[\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\theta_t)[\theta_t-\hat{\theta}_0,\theta_t-\hat{\theta}_0] \\ &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\hat{\theta}_0)[\theta_t-\hat{\theta}_0,\theta_t-\hat{\theta}_0] \\ &\approx \nabla^2 L(\theta_t) (\theta_t-\hat{\theta}_0) + \frac12 \nabla^3 L(\hat{\theta}_0)[u,u] \langle u, \theta_t - \hat{\theta}_0\rangle^2\,, \end{align*}\]where the second line replaces \(\nabla^3 L(\theta_t)\) by \(\nabla^3 L(\hat{\theta}_0)\), which only changes higher-order terms, and in the third line, as in the quadratic case, we keep only the component of \(\theta_t-\hat{\theta}_0\) in the direction \(u\). Using the approximation derived above for the quadratic part, this gives
\[\begin{align*} \nabla L(\theta_t) - \nabla L(\hat{\theta}_0) &\approx x_t S(\theta_t) u + \frac{x_t^2}{2} \nabla^3 L(\hat{\theta}_0)[u,u] \\ &= x_t S(\theta_t) u + \frac{x_t^2}{2} \nabla S(\hat{\theta}_0)\,, \end{align*}\]where we used the fact that if the top eigenvalue of \(\nabla^2 L(\theta)\) has multiplicity one, then \(\nabla S(\theta) = \nabla^3 L(\theta)[u(\theta),u(\theta)]\), applied at \(\theta = \hat{\theta}_0\).
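The identity \(\nabla S(\theta) = \nabla^3 L(\theta)[u(\theta),u(\theta)]\) can be checked numerically. The sketch below assumes a hypothetical toy loss \(L(\theta) = \tfrac12 \theta^\top A \theta + \tfrac16 \sum_i c_i \theta_i^3\) (not from the text), for which \(\nabla^2 L(\theta) = A + \mathrm{diag}(c \odot \theta)\) and \(\nabla^3 L(\theta)[u,u]_i = c_i u_i^2\). It compares this closed form with a central finite-difference gradient of the top eigenvalue computed via NumPy's `eigh`; the top eigenvalue is kept simple by giving \(A\) well-separated eigenvalues.

```python
import numpy as np

# Hypothetical toy loss (assumption for illustration):
#   L(theta) = 0.5 * theta^T A theta + (1/6) * sum_i c_i * theta_i**3
# so grad^2 L(theta) = A + diag(c * theta) and grad^3 L[v, w]_i = c_i * v_i * w_i.
rng = np.random.default_rng(1)
n = 5
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Q @ np.diag([10.0, 6.0, 4.0, 2.0, 1.0]) @ Q.T   # simple (non-repeated) top eigenvalue
c = rng.standard_normal(n)

def hessian(theta):
    return A + np.diag(c * theta)

def top_eig(theta):
    # eigh returns ascending eigenvalues and eigenvectors as columns
    w, V = np.linalg.eigh(hessian(theta))
    return w[-1], V[:, -1]          # S(theta) and u(theta)

def grad3_uu(u):
    # grad^3 L(theta)[u, u] for the toy loss (does not depend on theta here)
    return c * u**2

theta = 0.1 * rng.standard_normal(n)
S, u = top_eig(theta)

h = 1e-5
grad_S_fd = np.array([
    (top_eig(theta + h * e)[0] - top_eig(theta - h * e)[0]) / (2 * h)
    for e in np.eye(n)
])
print("max |FD grad S - grad^3 L[u,u]|:", np.max(np.abs(grad_S_fd - grad3_uu(u))))
```

The sign ambiguity of the eigenvector returned by `eigh` does not matter here, since \(u\) enters quadratically.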
The update for \(y_t\) is then
\[\begin{align*} y_{t+1} - y_t &= \langle \nabla S(\hat{\theta}_0), - \eta \nabla L(\theta_t)\rangle \\ &= -\eta \langle \nabla S(\hat{\theta}_0), \nabla L(\hat{\theta}_0)\rangle - \eta \langle \nabla S(\hat{\theta}_0), \nabla L(\theta_t)-\nabla L(\hat{\theta}_0)\rangle \\ &\approx \eta \alpha - \eta x_t S(\theta_t) \langle \nabla S(\hat{\theta}_0), u \rangle - \eta \frac{x_t^2}{2} \beta\,, \end{align*}\]where \(\alpha \seteq \langle \nabla S(\hat{\theta}_0), - \nabla L(\hat{\theta}_0)\rangle\) is the progressive sharpening parameter and \(\beta \seteq \|\nabla S(\hat{\theta}_0)\|^2\) is the “stabilization” parameter.
If we assume for simplicity that \(\langle \nabla S(\hat{\theta}_0), u\rangle = 0\), then we obtain
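\[\begin{equation}\label{eq:ypart} y_{t+1} - y_t \approx \eta \left(\alpha - \beta \frac{x_t^2}{2}\right)\,, \end{equation}\]meaning that (for \(\alpha > 0\)) \(y_t\), and hence by \eqref{eq:sharp} the sharpness, decreases whenever \(\abs{x_t} > \sqrt{2\alpha/\beta}\). Together with \eqref{eq:xpart}, this closes a feedback loop: once the sharpness exceeds \(2/\eta\), \(\abs{x_t}\) grows, and once \(\abs{x_t}\) is large enough, the sharpness is pushed back down.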
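The updates \eqref{eq:xpart} and \eqref{eq:ypart} form a closed two-dimensional system in \((x_t, y_t)\). The NumPy sketch below iterates this reduced model, treating both approximations as exact; the values of \(\eta\), \(\alpha\), \(\beta\) and the initial condition are illustrative assumptions, not from the text. With \(\alpha, \beta > 0\) one expects \(y_t\) to rise first, \(\abs{x_t}\) to grow past \(\sqrt{2\alpha/\beta}\), and \(y_t\) then to turn negative, so that the sharpness \(2/\eta + y_t\) oscillates around \(2/\eta\) rather than growing indefinitely.

```python
import numpy as np

def simulate(eta, alpha, beta, x0, y0, steps):
    """Iterate the reduced model
         x_{t+1} = -(1 + eta * y_t) * x_t
         y_{t+1} = y_t + eta * (alpha - beta * x_t**2 / 2)
    using (x_t, y_t) on the right-hand side of both updates."""
    xs, ys = [x0], [y0]
    x, y = x0, y0
    for _ in range(steps):
        # Tuple assignment: both right-hand sides use the old (x, y).
        x, y = -(1.0 + eta * y) * x, y + eta * (alpha - beta * x**2 / 2)
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

# Hypothetical parameters: alpha > 0 (progressive sharpening), beta > 0 (stabilization).
eta, alpha, beta = 0.01, 1.0, 1.0
xs, ys = simulate(eta, alpha, beta, x0=1e-2, y0=0.0, steps=1500)

sharpness = 2.0 / eta + ys          # S(theta_t) ≈ 2/eta + y_t
print("threshold sqrt(2 alpha / beta):", np.sqrt(2 * alpha / beta))
print("max |x_t|:", np.abs(xs).max())
print("sharpness range:", sharpness.min(), sharpness.max())
```

For comparison, setting `beta = 0.0` removes the stabilizing term; then \(y_t = \eta \alpha t\) exactly, so the sharpness grows linearly without bound (\(y_{1500} \approx 15\) with the parameters above) while \(\abs{x_t}\) blows up.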