This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from the paper
[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://papers.labml.ai/paper/2305.14342).
+ The official implementation is available at [Liuhong99/Sophia](https://github.com/Liuhong99/Sophia).
+
+ Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant to
+ non-convexity and rapid changes of the Hessian than Newton's method, and uses a
+ low-cost pre-conditioner.
+
+ Sophia keeps an exponential moving average (EMA) of diagonal Hessian estimates across iterations.
+ The diagonal Hessian estimate $\hat{h}_t$ is calculated every $k$ steps.
+
+ \begin{align}
+ h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \ \ \ \ \text{ if } t \text{ mod } k = 1; \text{ else } h_t = h_{t-1}
+ \end{align}
+
+ Sophia uses an EMA of the gradients $m_t$, considers only positive entries of
+ the diagonal Hessian, and applies per-coordinate clipping to the update.
+
+ \begin{align}
+ m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+ \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{\max \{h_t, \epsilon \}}, \rho \bigg)
+ \end{align}
+
+ where $\epsilon$ is a very small value that prevents division by $0$.
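+
+ As a rough illustration (a sketch only, not the implementation below), the clipping is a
+ per-coordinate clamp of $\frac{m_t}{\max \{h_t, \epsilon \}}$ to $[-\rho, \rho]$. Here `m`, `h`,
+ `lr` and `rho` are assumed to be already computed, and `lr` plays the role of $\eta$:
+
+ ```python
+ import torch
+
+ def sophia_update(param: torch.Tensor, m: torch.Tensor, h: torch.Tensor,
+                   lr: float, rho: float, eps: float = 1e-12):
+     # clip(m / max(h, eps), rho), applied per coordinate
+     update = (m / torch.clamp(h, min=eps)).clamp(-rho, rho)
+     # Step by -lr along the clipped direction
+     param.data.add_(update, alpha=-lr)
+ ```
+
+ The `Sophia` class below implements an equivalent form in which its `lr` parameter is
+ the maximum per-coordinate step $\eta \rho$.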
+
+ ### Gauss-Newton-Bartlett (GNB) estimator
+
+ \begin{align}
+ \hat{L}(\theta) &= \frac{1}{B} \sum^{B}_{b=1} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big) \\
+ \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta)
+ \end{align}
+
+ where $x_b$ are the inputs,
+ $B$ is the batch size (number of inputs/tokens),
+ $\ell_{CE}$ is the cross entropy loss, and
+ $\hat{y}_b$ are labels sampled from the logits $f(\theta, x_b)$.
+
+ Note that this Hessian estimate is always non-negative, so we
+ can replace $\max \{h_t, \epsilon \}$ with $h_t + \epsilon$.
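+
+ A minimal self-contained sketch of this estimate for a toy linear model (the names and
+ shapes are illustrative assumptions, not part of this implementation):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ w = torch.randn(16, 8, requires_grad=True)  # parameters $\theta$
+ x = torch.randn(32, 16)                     # $B = 32$ inputs/tokens
+ logits = x @ w                              # $f(\theta, x_b)$
+
+ # Sample labels $\hat{y}_b$ from the model's own logits
+ y_hat = torch.distributions.Categorical(logits=logits).sample()
+
+ # Mini-batch cross entropy loss on the sampled labels, and its gradient
+ loss = F.cross_entropy(logits, y_hat)
+ (g,) = torch.autograd.grad(loss, [w])
+
+ # GNB estimate of the Hessian diagonal: $B \cdot g \odot g$ (elementwise, hence never negative)
+ h_hat = x.shape[0] * g * g
+ ```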
+
+ Sophia with the Gauss-Newton-Bartlett (GNB) estimator is called **Sophia-G**.
+
+ Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.
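+
+ A rough, illustrative usage sketch (not taken from that experiment; `model`, the `(x, y)`
+ batch iterator `data`, and the update interval `k` are assumptions):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ optimizer = Sophia(model.parameters(), lr=1e-4, betas=(0.9, 0.95), rho=0.03)
+ k = 10  # re-estimate the Hessian diagonal every k steps
+
+ for step, (x, y) in enumerate(data):
+     # Standard training step on the true labels
+     logits = model(x)
+     loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
+     optimizer.zero_grad()
+     loss.backward()
+     optimizer.step()
+
+     # GNB Hessian estimate on labels sampled from the model's own logits
+     if step % k == k - 1:
+         logits = model(x)
+         y_hat = torch.distributions.Categorical(logits=logits).sample()
+         gnb_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y_hat.view(-1))
+         optimizer.zero_grad()
+         gnb_loss.backward()
+         optimizer.update_hessian(n_tokens_training_batch=y_hat.numel())
+         optimizer.zero_grad()
+ ```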
1152"""
1253
1354from typing import Dict , Any , Tuple , Optional
@@ -27,15 +68,15 @@ class Sophia(GenericAdaptiveOptimizer):
2768 """
2869
2970 def __init__ (self , params ,
30- lr : float = 1e-4 , betas : Tuple [float , float ] = (0.965 , 0.99 ), eps : float = 1e-16 ,
31- rho : float = 0.04 ,
71+ lr : float = 1e-4 , betas : Tuple [float , float ] = (0.9 , 0.95 ), eps : float = 1e-12 ,
72+ rho : float = 0.03 ,
3273 weight_decay : WeightDecay = WeightDecay (),
3374 defaults : Optional [Dict [str , Any ]] = None ):
3475 """
3576 ### Initialize the optimizer
3677
3778 * `params` is the list of parameters
38- * `lr` is the learning rate $\a lpha $
79+ * `lr` is the maximum learning rate $\eta \r ho $
3980 * `betas` is a tuple of ($\b eta_1$, $\b eta_2$)
4081 * `eps` is $\epsilon$
4182 * `pho` is $\r ho$
@@ -61,23 +102,46 @@ def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Par
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
-         # state['hessian_updates']
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
-         # Exponential moving average of Hessian
+         # Exponential moving average of Hessian diagonal, $h_t$
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def update_hessian(self, n_tokens_training_batch):
+         """
+         ### Update the EMA of Hessian diagonal $h_t$
+
+         * `n_tokens_training_batch` is the number of tokens/inputs in the batch $B$
+
+         \begin{align}
+         \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
+         h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
+         \end{align}
+         """
+
+         # Iterate through parameter groups
        for group in self.param_groups:
-             beta1, beta2 = group['betas']
+             # $\beta_2$
+             _, beta2 = group['betas']
+             # Iterate through parameters
            for p in group['params']:
+                 # Skip parameters without gradients
                if p.grad is None:
                    continue
+
+                 # Get optimizer state
                state = self.state[p]

+                 # Initialize state if empty
                if len(state) == 0:
                    self.init_state(state, group, p)

+                 # Update EMA Hessian diagonal
+                 #
+                 # \begin{align}
+                 # \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
+                 # h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
+                 # \end{align}
                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
@@ -88,17 +152,24 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.T
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
+
+         We make the following parameter update,
+
+         \begin{align}
+         \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg) \\
+         \theta_{t + 1} &\leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)
+         \end{align}
+
+         The second form is equivalent to the first up to a rescaling of $\epsilon$,
+         since $\operatorname{clip}(u, \rho) = \rho \cdot \operatorname{clip}(u / \rho, 1)$;
+         it is the form implemented here, with `lr` holding $\eta \rho$.
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
-
+         # Get $\rho$
        rho = group['rho']

-         # Get $m_{t-1}$ and $v_{t-1}$
+         # Get $m_{t-1}$ and $h_t$
        m, hessian = state['exp_avg'], state['hessian']

        # In-place calculation of $m_t$
@@ -108,9 +179,11 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.T
        # Increment $t$ the number of optimizer steps
        state['step'] += 1

-         # Get learning rate
+         # Get maximum learning rate $\eta \rho$
        lr = group['lr']

-         ratio = (m.abs() / (rho * hessian + group['eps'])).clamp(None, 1)
+         # $$\operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
+         ratio = (m / (rho * hessian + group['eps'])).clamp(-1, 1)

-         param.data.addcmul_(m.sign(), ratio, value=-lr)
+         # $$\theta_{t + 1} \leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
+         param.data.add_(ratio, alpha=-lr)