
Commit 8db330d

sophia-g docs
1 parent 7c02294 commit 8db330d

6 files changed: 1389 additions & 99 deletions


docs/optimizers/configs.html

Lines changed: 86 additions & 76 deletions
Large diffs are not rendered by default.

docs/optimizers/sophia.html

Lines changed: 489 additions & 0 deletions
Large diffs are not rendered by default.

docs/sitemap.xml

Lines changed: 15 additions & 1 deletion
@@ -687,7 +687,7 @@
 
   <url>
     <loc>https://nn.labml.ai/optimizers/configs.html</loc>
-    <lastmod>2021-10-21T16:30:00+00:00</lastmod>
+    <lastmod>2023-07-14T16:30:00+00:00</lastmod>
    <priority>1.00</priority>
  </url>
 
@@ -755,6 +755,13 @@
   </url>
 
 
+  <url>
+    <loc>https://nn.labml.ai/optimizers/sophia.html</loc>
+    <lastmod>2023-07-14T16:30:00+00:00</lastmod>
+    <priority>1.00</priority>
+  </url>
+
+
   <url>
     <loc>https://nn.labml.ai/optimizers/amsgrad.html</loc>
     <lastmod>2023-04-02T16:30:00+00:00</lastmod>
@@ -965,6 +972,13 @@
   </url>
 
 
+  <url>
+    <loc>https://nn.labml.ai/transformers/basic/with_sophia.html</loc>
+    <lastmod>2023-07-14T16:30:00+00:00</lastmod>
+    <priority>1.00</priority>
+  </url>
+
+
   <url>
     <loc>https://nn.labml.ai/transformers/basic/index.html</loc>
     <lastmod>2021-06-07T16:30:00+00:00</lastmod>

docs/transformers/basic/with_sophia.html

Lines changed: 690 additions & 0 deletions
Large diffs are not rendered by default.

labml_nn/optimizers/sophia.py

Lines changed: 84 additions & 11 deletions
@@ -8,6 +8,47 @@
 
 This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from paper
 [Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://papers.labml.ai/paper/2305.14342).
+Official implementation is available at [Liuhong99/Sophia](https://github.com/Liuhong99/Sophia).
+
+Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant
+to non-convexity and rapid changes of the Hessian than Newton's method, and uses a low-cost
+pre-conditioner.
+
+Sophia keeps diagonal Hessian estimates with EMA across iterations.
+The diagonal Hessian $\hat{h}_t$ is calculated every $k$ steps.
+
+\begin{align}
+h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \ \ \ \ \text{ if } t \text{ mod } k = 1; \text{ else } h_t = h_{t-1}
+\end{align}
+
+Sophia uses an EMA of gradients $m_t$, considers only the positive entries of
+the diagonal Hessian, and applies per-coordinate clipping to the update.
+
+\begin{align}
+m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+\theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg( \frac{m_t}{\max \{h_t, \epsilon \}}, \rho \bigg)
+\end{align}
+
+where $\epsilon$ is a very small value to prevent division by $0$.
+
+### Gauss-Newton-Bartlett (GNB) estimator
+
+\begin{align}
+\hat{L}(\theta) &= \frac{1}{B} \sum^{B}_{b=1} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big) \\
+\hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta)
+\end{align}
+
+where $x_b$ are the inputs,
+$B$ is the batch size (number of inputs/tokens),
+$\ell_{CE}$ is the cross entropy loss, and
+$\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.
+
+Note that this Hessian estimate is always positive, so we
+can replace $\max \{h_t, \epsilon \}$ with $h_t + \epsilon$.
+
+Sophia with the Gauss-Newton-Bartlett (GNB) estimator is **Sophia-G**.
+
+Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.
 """
 
 from typing import Dict, Any, Tuple, Optional
@@ -27,15 +68,15 @@ class Sophia(GenericAdaptiveOptimizer):
     """
 
     def __init__(self, params,
-                 lr: float = 1e-4, betas: Tuple[float, float] = (0.965, 0.99), eps: float = 1e-16,
-                 rho: float = 0.04,
+                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
+                 rho: float = 0.03,
                  weight_decay: WeightDecay = WeightDecay(),
                  defaults: Optional[Dict[str, Any]] = None):
         """
         ### Initialize the optimizer
 
         * `params` is the list of parameters
-        * `lr` is the learning rate $\alpha$
+        * `lr` is the maximum learning rate $\eta \rho$
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
         * `eps` is $\epsilon$
         * `rho` is $\rho$
@@ -61,23 +102,46 @@ def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
 
         # This is the number of optimizer steps taken on the parameter, $t$
         state['step'] = 0
-        # state['hessian_updates']
         # Exponential moving average of gradients, $m_t$
         state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
-        # Exponential moving average of Hessian
+        # Exponential moving average of Hessian diagonal, $h_t$
         state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
 
     def update_hessian(self, n_tokens_training_batch):
+        """
+        ### Update the EMA of Hessian diagonal $h_t$
+
+        * `n_tokens_training_batch` is the number of tokens/inputs in the batch $B$
+
+        \begin{align}
+        \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
+        h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
+        \end{align}
+        """
+
+        # Iterate through parameter groups
         for group in self.param_groups:
-            beta1, beta2 = group['betas']
+            # $\beta_2$
+            _, beta2 = group['betas']
+            # Iterate through parameters
             for p in group['params']:
+                # Skip parameters without gradients
                 if p.grad is None:
                     continue
+
+                # Get optimizer state
                 state = self.state[p]
 
+                # Initialize state if empty
                 if len(state) == 0:
                     self.init_state(state, group, p)
 
+                # Update EMA Hessian diagonal
+                #
+                # \begin{align}
+                # \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
+                # h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
+                # \end{align}
                 state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
@@ -88,17 +152,24 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         * `group` stores optimizer attributes of the parameter group
         * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
         * `param` is the parameter tensor $\theta_{t-1}$
+
+        We do the following parameter update,
+
+        \begin{align}
+        \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg( \frac{m_t}{h_t + \epsilon}, \rho \bigg) \\
+        \theta_{t + 1} &\leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg( \frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)
+        \end{align}
         """
 
         # Calculate weight decay
         grad = self.weight_decay(param, grad, group)
 
         # Get $\beta_1$ and $\beta_2$
         beta1, beta2 = group['betas']
-
+        # Get $\rho$
         rho = group['rho']
 
-        # Get $m_{t-1}$ and $v_{t-1}$
+        # Get $m_{t-1}$ and $h_t$
         m, hessian = state['exp_avg'], state['hessian']
 
         # In-place calculation of $m_t$
@@ -108,9 +179,11 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         # Increment $t$ the number of optimizer steps
         state['step'] += 1
 
-        # Get learning rate
+        # Get maximum learning rate $\eta \rho$
        lr = group['lr']
 
-        ratio = (m.abs() / (rho * hessian + group['eps'])).clamp(None, 1)
+        # $$\operatorname{clip} \bigg( \frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
+        ratio = (m / (rho * hessian + group['eps'])).clamp(-1, 1)
 
-        param.data.addcmul_(m.sign(), ratio, value=-lr)
+        # $$\theta_{t + 1} \leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg( \frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
+        param.data.add_(ratio, alpha=-lr)
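
The last hunk folds $\rho$ into the learning rate: with `lr` $= \eta \rho$, clipping `m / (rho * hessian + eps)` to $[-1, 1]$ gives the same per-coordinate update as the paper's $\eta \cdot \operatorname{clip} ( m_t / \max \{h_t, \epsilon \}, \rho )$. A quick numerical check of that equivalence, as a standalone sketch (the tensor values are illustrative, not from the commit):

import torch

# `lr` plays the role of eta * rho, matching the `lr` semantics in the diff above
lr, rho, eps = 3e-4, 0.03, 1e-12

m = torch.tensor([0.02, -0.5, 1e-8, 0.001])  # EMA of gradients, m_t
h = torch.tensor([0.1, 4.0, 0.0, 1.0])       # EMA of Hessian diagonal, h_t

# Form used in the commit: lr * clip(m / (rho * h + eps), -1, 1)
update_commit = -lr * (m / (rho * h + eps)).clamp(-1, 1)

# Paper's form: eta * clip(m / max(h, eps), -rho, rho), with eta = lr / rho
eta = lr / rho
update_paper = -eta * (m / h.clamp(min=eps)).clamp(-rho, rho)

# The two agree per coordinate (up to how eps enters the denominator);
# the last coordinate shows the unclipped case, the others saturate the clip
print(update_commit)
print(update_paper)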

labml_nn/transformers/basic/with_sophia.py

Lines changed: 25 additions & 11 deletions
@@ -1,5 +1,16 @@
+"""
+---
+title: Transformer Auto-Regression Experiment with [Sophia-G optimizer](../../optimizers/sophia.html)
+summary: >
+  This trains a simple transformer model on NLP auto-regression with the Sophia-G optimizer.
+---
+
+# Transformer Auto-Regression Experiment with [Sophia-G optimizer](../../optimizers/sophia.html)
+
+This trains a simple transformer introduced in [Attention Is All You Need](https://papers.labml.ai/paper/1706.03762)
+on an NLP auto-regression task (with Tiny Shakespeare dataset) with the [Sophia-G optimizer](../../optimizers/sophia.html).
+"""
 import torch
-from labml.configs import option
 
 from labml import experiment, tracker
 from labml_helpers.train_valid import BatchIndex
@@ -20,7 +31,7 @@ class Configs(TransformerAutoRegressionConfigs):
 
     def step(self, batch: any, batch_idx: BatchIndex):
         """
-        ### Training or validation step
+        ### Training or validation step with Gauss-Newton-Bartlett (GNB) Hessian diagonal estimator
         """
 
         # Set training/eval mode
@@ -29,15 +40,14 @@ def step(self, batch: any, batch_idx: BatchIndex):
         # Move data to the device
         data, target = batch[0].to(self.device), batch[1].to(self.device)
 
+        # Estimate the Hessian diagonal every $k$ steps
         if isinstance(self.optimizer, Sophia) and self.mode.is_train and batch_idx.idx % self.hess_interval == 0:
-            # Whether to capture model outputs
-            with self.mode.update(is_log_activations=False):
-                # Get model outputs.
-                # It's returning a tuple for states when using RNNs.
-                # This is not implemented yet. 😜
-                output, *_ = self.model(data)
+            # Get model outputs
+            output, *_ = self.model(data)
 
+            # Create a categorical distribution from logits
             samp_dist = torch.distributions.Categorical(logits=output)
+            # Sample $\hat{y}$
             y_sample = samp_dist.sample()
 
             # Calculate and log loss
@@ -48,7 +58,12 @@ def step(self, batch: any, batch_idx: BatchIndex):
             loss.backward()
             # Clip gradients
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
-            # Update Hessian estimate
+            # Update EMA Hessian diagonal
+            #
+            # \begin{align}
+            # \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
+            # h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
+            # \end{align}
             self.optimizer.update_hessian(data.numel())
             # Clear the gradients
             self.optimizer.zero_grad()
@@ -95,7 +110,6 @@ def step(self, batch: any, batch_idx: BatchIndex):
         tracker.save()
 
 
-
 def main():
     # Create experiment
     experiment.create(name="transformer")
@@ -127,7 +141,7 @@ def main():
         'transformer.n_heads': 16,
         'transformer.ffn.d_ff': 1024,
 
-        # Use [Noam optimizer](../../optimizers/noam.html)
+        # Use [Sophia optimizer](../../optimizers/sophia.html)
         'optimizer.optimizer': 'Sophia',
         'optimizer.learning_rate': 3e-4,
         'optimizer.rho': 0.03,
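
For readers outside the labml experiment framework, here is a condensed sketch of the same Sophia-G loop: every $k$ steps, backprop the cross entropy loss on labels sampled from the model's own logits and feed the resulting gradients to `update_hessian`, then take normal Sophia steps on the true labels. The `Sophia` import matches labml_nn/optimizers/sophia.py from this commit; the toy model, synthetic loader, and value of `k` are illustrative assumptions, not part of the commit.

import torch
import torch.nn.functional as F
from labml_nn.optimizers.sophia import Sophia

# Illustrative stand-ins: a tiny classifier and a synthetic dataset
model = torch.nn.Linear(16, 10)
loader = [(torch.randn(32, 16), torch.randint(0, 10, (32,))) for _ in range(100)]

optimizer = Sophia(model.parameters(), lr=3e-4, betas=(0.9, 0.95), rho=0.03)
k = 10  # Hessian update interval, playing the role of `hess_interval` above

for step, (data, target) in enumerate(loader):
    if step % k == 0:
        # GNB estimator: sample labels from the model's own logits
        logits = model(data)
        y_sample = torch.distributions.Categorical(logits=logits).sample()
        # Backprop the cross entropy loss on the sampled labels
        F.cross_entropy(logits, y_sample).backward()
        # Update the EMA Hessian diagonal; B is the batch size here
        optimizer.update_hessian(data.shape[0])
        optimizer.zero_grad()

    # Normal training step on the true labels
    F.cross_entropy(model(data), target).backward()
    optimizer.step()
    optimizer.zero_grad()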
