
Commit 7c02294: sophia exp
1 parent f45ca5e

3 files changed: 162 additions & 13 deletions


labml_nn/optimizers/configs.py

Lines changed: 10 additions & 0 deletions
@@ -67,6 +67,8 @@ class OptimizerConfigs(BaseConfigs):
     # Model embedding size for Noam optimizer
     d_model: int
 
+    rho: float
+
     def __init__(self):
         super().__init__(_primary='optimizer')
 
@@ -137,6 +139,14 @@ def _noam_optimizer(c: OptimizerConfigs):
                 d_model=c.d_model)
 
 
+@option(OptimizerConfigs.optimizer, 'Sophia')
+def _sophia_optimizer(c: OptimizerConfigs):
+    from labml_nn.optimizers.sophia import Sophia
+    return Sophia(c.parameters,
+                  lr=c.learning_rate, betas=c.betas, eps=c.eps,
+                  weight_decay=c.weight_decay_obj, rho=c.rho)
+
+
 @option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
 def _noam_optimizer(c: OptimizerConfigs):
     from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
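With this registration in place, 'Sophia' can be selected like any other labml optimizer option. A minimal sketch, assuming the usual `OptimizerConfigs` pattern used elsewhere in labml_nn; `model` is a stand-in for the experiment's module, and the values simply mirror the experiment added below:

    from labml_nn.optimizers.configs import OptimizerConfigs

    # Pick the 'Sophia' option registered above and set the new `rho` config.
    # When the experiment asks for the optimizer, labml resolves it
    # through `_sophia_optimizer`.
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = model.parameters()
    opt_conf.optimizer = 'Sophia'
    opt_conf.learning_rate = 3e-4
    opt_conf.rho = 0.03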

labml_nn/optimizers/sophia.py

Lines changed: 5 additions & 13 deletions
@@ -29,9 +29,7 @@ class Sophia(GenericAdaptiveOptimizer):
     def __init__(self, params,
                  lr: float = 1e-4, betas: Tuple[float, float] = (0.965, 0.99), eps: float = 1e-16,
                  rho: float = 0.04,
-                 training_batch_tokens: int = None,
                  weight_decay: WeightDecay = WeightDecay(),
-                 optimized_update: bool = True,
                  defaults: Optional[Dict[str, Any]] = None):
         """
         ### Initialize the optimizer
@@ -42,21 +40,15 @@ def __init__(self, params,
         * `eps` is $\epsilon$
         * `rho` is $\rho$
         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
-        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
-          by doing it after adding $\epsilon$
         * `defaults` is a dictionary of defaults for group values.
         This is useful when you want to extend the class `Sophia`.
         """
-        if training_batch_tokens is None:
-            raise RuntimeError('Please set the number of tokens per training batch.')
-
         defaults = {} if defaults is None else defaults
         defaults.update(weight_decay.defaults())
-        defaults.update(dict(rho=rho, training_batch_tokens=training_batch_tokens))
+        defaults.update(dict(rho=rho))
         super().__init__(params, defaults, lr, betas, eps)
 
         self.weight_decay = weight_decay
-        self.optimized_update = optimized_update
 
     def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
         """
@@ -75,7 +67,7 @@ def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
         # Exponential moving average of Hessian
         state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
 
-    def update_hessian(self, batch_size):
+    def update_hessian(self, n_tokens_training_batch):
         for group in self.param_groups:
             beta1, beta2 = group['betas']
             for p in group['params']:
@@ -86,7 +78,7 @@ def update_hessian(self, batch_size):
                 if len(state) == 0:
                     self.init_state(state, group, p)
 
-                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * batch_size)
+                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         """
@@ -107,7 +99,7 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         rho = group['rho']
 
         # Get $m_{t-1}$ and $h_{t-1}$
-        m, hessian = state['exp_avg'], state['hessain']
+        m, hessian = state['exp_avg'], state['hessian']
 
         # In-place calculation of $m_t$
         # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
@@ -119,6 +111,6 @@ def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         # Get learning rate
         lr = group['lr']
 
-        ratio = (m.abs() / (rho * hessian + group['training_batch_tokens'] * group['eps'])).clamp(None, 1)
+        ratio = (m.abs() / (rho * hessian + group['eps'])).clamp(None, 1)
 
         param.data.addcmul_(m.sign(), ratio, value=-lr)
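Read together, the new `update_hessian` and `step_param` implement the following per-parameter update, with $B$ the number of tokens in a training batch (the `n_tokens_training_batch` argument); the batch-size scaling has moved from the denominator of the ratio into the Hessian EMA:

$$h_t \leftarrow \beta_2 h_{t-1} + (1 - \beta_2) \cdot B \cdot \hat{g}_t \odot \hat{g}_t$$
$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
$$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{sign}(m_t) \cdot \min \left( \frac{|m_t|}{\rho \cdot h_t + \epsilon},\ 1 \right)$$

Here $g_t$ is the regular mini-batch gradient, and $\hat{g}_t$ is the gradient of the loss on labels sampled from the model, which the experiment below computes every few batches.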
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+import torch
+from labml.configs import option
+
+from labml import experiment, tracker
+from labml_helpers.train_valid import BatchIndex
+from labml_nn.optimizers.sophia import Sophia
+from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs
+
+
+class Configs(TransformerAutoRegressionConfigs):
+    """
+    ## Configurations
+
+    This inherits from [`Configs`](autoregressive_experiment.html)
+    """
+
+    # Number of batches between Hessian estimate updates
+    hess_interval: int = 10
+
+    optimizer: Sophia
+
+    def step(self, batch: any, batch_idx: BatchIndex):
+        """
+        ### Training or validation step
+        """
+
+        # Set training/eval mode
+        self.model.train(self.mode.is_train)
+
+        # Move data to the device
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        if isinstance(self.optimizer, Sophia) and self.mode.is_train and batch_idx.idx % self.hess_interval == 0:
+            # Whether to capture model outputs
+            with self.mode.update(is_log_activations=False):
+                # Get model outputs.
+                # It's returning a tuple for states when using RNNs.
+                # This is not implemented yet. 😜
+                output, *_ = self.model(data)
+
+            # Sample labels from the model's own output distribution
+            samp_dist = torch.distributions.Categorical(logits=output)
+            y_sample = samp_dist.sample()
+
+            # Calculate and log loss
+            loss = self.loss_func(output, y_sample)
+            tracker.add("loss.hess.", loss)
+
+            # Calculate gradients
+            loss.backward()
+            # Clip gradients
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+            # Update the Hessian estimate with the number of tokens in the batch
+            self.optimizer.update_hessian(data.numel())
+            # Clear the gradients
+            self.optimizer.zero_grad()
+        else:
+            # Update global step (number of tokens processed) when in training mode
+            if self.mode.is_train:
+                tracker.add_global_step(data.shape[0] * data.shape[1])
+
+            # Whether to capture model outputs
+            with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
+                # Get model outputs.
+                # It's returning a tuple for states when using RNNs.
+                # This is not implemented yet. 😜
+                output, *_ = self.model(data)
+
+            # Calculate and log loss
+            loss = self.loss_func(output, target)
+            tracker.add("loss.", loss)
+
+            # Calculate and log accuracy
+            self.accuracy(output, target)
+            self.accuracy.track()
+
+            self.other_metrics(output, target)
+
+            # Train the model
+            if self.mode.is_train:
+                # Calculate gradients
+                loss.backward()
+                # Clip gradients
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+                # Take optimizer step
+                self.optimizer.step()
+                # Log the model parameters and gradients on last batch of every epoch
+                if batch_idx.is_last and self.is_log_model_params_grads:
+                    tracker.add('model', self.model)
+                # Clear the gradients
+                self.optimizer.zero_grad()
+
+            # Save the tracked metrics
+            tracker.save()
+
+
+
+def main():
+    # Create experiment
+    experiment.create(name="transformer")
+    # Create configs
+    conf = Configs()
+    # Override configurations
+    experiment.configs(conf, {
+        # Use character level tokenizer
+        'tokenizer': 'character',
+        # Prompt separator is blank
+        'prompt_separator': '',
+        # Starting prompt for sampling
+        'prompt': 'It is ',
+        # Use Tiny Shakespeare dataset
+        'text': 'tiny_shakespeare',
+
+        # Use a context size of $512$
+        'seq_len': 512,
+        # Train for 32 epochs
+        'epochs': 32,
+        # Batch size $16$
+        'batch_size': 16,
+        # Switch between training and validation for $10$ times
+        # per epoch
+        'inner_iterations': 10,
+
+        # Model size
+        'd_model': 256,
+        'transformer.n_heads': 16,
+        'transformer.ffn.d_ff': 1024,
+
+        # Use [Sophia optimizer](../../optimizers/sophia.html)
+        'optimizer.optimizer': 'Sophia',
+        'optimizer.learning_rate': 3e-4,
+        'optimizer.rho': 0.03,
+    })
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models({'model': conf.model})
+
+    # Start the experiment
+    with experiment.start():
+        # Run training
+        conf.run()
+
+
+#
+if __name__ == '__main__':
+    main()
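The Hessian branch of `step` above matches the Gauss-Newton-Bartlett estimator from the Sophia paper: labels are sampled from the model's own output distribution, the loss against those sampled labels is backpropagated, and the squared gradient, scaled by the token count $B$ that `update_hessian(data.numel())` passes in, feeds the diagonal Hessian EMA:

$$\hat{y}_b \sim \operatorname{softmax}(f_\theta(x_b)), \qquad \hat{L}(\theta) = \frac{1}{B} \sum_{b} \ell\big(f_\theta(x_b), \hat{y}_b\big), \qquad \hat{h} = B \cdot \nabla \hat{L}(\theta) \odot \nabla \hat{L}(\theta)$$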
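Outside the labml experiment harness, the same cadence can be reproduced in a few lines. A minimal sketch, assuming only the `Sophia` class from this commit; the toy model, data, and hyper-parameters are illustrative, not part of the commit:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from labml_nn.optimizers.sophia import Sophia

    model = nn.Linear(16, 4)
    optimizer = Sophia(model.parameters(), lr=3e-4, rho=0.03)

    hess_interval = 10
    for step_idx in range(100):
        x, y = torch.randn(32, 16), torch.randint(0, 4, (32,))
        logits = model(x)
        if step_idx % hess_interval == 0:
            # Hessian-estimate step: labels sampled from the model itself
            y_hat = torch.distributions.Categorical(logits=logits).sample()
            F.cross_entropy(logits, y_hat).backward()
            optimizer.update_hessian(x.shape[0])  # tokens/samples in the batch
            optimizer.zero_grad()
        else:
            # Regular parameter update on the true labels
            F.cross_entropy(logits, y).backward()
            optimizer.step()
            optimizer.zero_grad()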
