| 1 | +""" |
| 2 | +--- |
| 3 | +title: Sophia Optimizer |
| 4 | +summary: A simple PyTorch implementation/tutorial of Sophia optimizer |
| 5 | +--- |
| 6 | +
|
| 7 | +# Sophia Optimizer |
| 8 | +
|
| 9 | +This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from paper |
| 10 | + [Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://papers.labml.ai/paper/2305.14342). |
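
Sophia-G keeps exponential moving averages of the gradients, $m_t$, and of a diagonal
Hessian estimate, $h_t$ (refreshed every few steps with the Gauss-Newton-Bartlett estimator),
and takes a per-coordinate update whose magnitude is clipped. Per the paper,

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
h_t &\leftarrow \beta_2 h_{t-k} + (1 - \beta_2) \cdot \hat{h}_t \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \operatorname{clip} \Bigg( \frac{m_t}{\max \{ \rho \cdot h_t, \epsilon \}}, 1 \Bigg)
\end{align}

where the clipping is element-wise to $[-1, 1]$ and $\hat{h}_t$ is the Hessian estimate
computed from gradients of the loss on labels sampled from the model's output distribution.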
| 11 | +""" |
| 12 | + |
| 13 | +from typing import Dict, Any, Tuple, Optional |
| 14 | + |
| 15 | +import torch |
| 16 | +from torch import nn |
| 17 | + |
| 18 | +from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay |
| 19 | + |
| 20 | + |
class Sophia(GenericAdaptiveOptimizer):
    """
    ## Sophia-G Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Sophia optimizer.
    """

    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.965, 0.99), eps: float = 1e-16,
                 rho: float = 0.04,
                 training_batch_tokens: Optional[int] = None,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
| 36 | + """ |
| 37 | + ### Initialize the optimizer |
| 38 | +
|
| 39 | + * `params` is the list of parameters |
| 40 | + * `lr` is the learning rate $\alpha$ |
| 41 | + * `betas` is a tuple of ($\beta_1$, $\beta_2$) |
| 42 | + * `eps` is $\epsilon$ |
| 43 | + * `pho` is $\rho$ |
| 44 | + * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html) |
| 45 | + * `optimized_update` is a flag whether to optimize the bias correction of the second moment |
| 46 | + by doing it after adding $\epsilon$ |
| 47 | + * `defaults` is a dictionary of default for group values. |
| 48 | + This is useful when you want to extend the class `Adam`. |
| 49 | + """ |
        if training_batch_tokens is None:
            raise RuntimeError('Please set the number of tokens per training batch.')

        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho, training_batch_tokens=training_batch_tokens))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of the diagonal Hessian estimate, $h_t$
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def update_hessian(self, batch_size):
        """
        ### Update the Hessian estimate

        Maintains an exponential moving average of the Gauss-Newton-Bartlett Hessian estimate.
        Call this periodically after computing gradients of the loss on labels sampled from the
        model's output distribution. `batch_size` is the number of tokens in the batch, $B$.
        """
        for group in self.param_groups:
            # Get $\beta_2$
            _, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Get the optimizer state of the parameter
                state = self.state[p]

                # Initialize the state if it is empty
                if len(state) == 0:
                    self.init_state(state, group, p)

                # Update the moving average of the Hessian estimate,
                # $$h_t \leftarrow \beta_2 h_{t-1} + (1 - \beta_2) \cdot B \hat{g} \odot \hat{g}$$
                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * batch_size)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $\rho$
        rho = group['rho']

        # Get $m_{t-1}$ and the Hessian estimate $h_{t-1}$
        m, hessian = state['exp_avg'], state['hessian']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Increment $t$ the number of optimizer steps
        state['step'] += 1

        # Get learning rate
        lr = group['lr']

        # Per-coordinate magnitude of the update, clipped at $1$,
        # $$\min \Bigg( \frac{|m_t|}{\rho \cdot h_t + B \epsilon}, 1 \Bigg)$$
        # The Hessian moving average $h_t$ already includes the batch-size factor $B$,
        # so $\epsilon$ is scaled by the number of batch tokens as well.
        ratio = (m.abs() / (rho * hessian + group['training_batch_tokens'] * group['eps'])).clamp(None, 1)

        # Take the clipped, sign-based update step,
        # $$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \operatorname{sign}(m_t) \cdot \text{ratio}$$
        param.data.addcmul_(m.sign(), ratio, value=-lr)
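

# Below is a minimal usage sketch, not part of the optimizer and purely illustrative:
# the toy model, random data, and the 10-step Hessian refresh interval are assumptions
# made for demonstration. It shows the pattern of taking regular steps on the true labels
# and periodically back-propagating the loss on labels sampled from the model's own output
# distribution before calling `update_hessian` (the Gauss-Newton-Bartlett estimator).
def _usage_sketch():
    # Number of tokens per training batch: batch size times sequence length
    batch_size, seq_len, vocab = 32, 16, 100
    batch_tokens = batch_size * seq_len

    # A toy "language model" and loss
    model = nn.Sequential(nn.Embedding(vocab, 64), nn.Linear(64, vocab))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Sophia(model.parameters(), lr=1e-4, training_batch_tokens=batch_tokens)

    for step in range(100):
        # Random token ids standing in for a real training batch
        x = torch.randint(0, vocab, (batch_size, seq_len))
        y = torch.randint(0, vocab, (batch_size, seq_len))

        # Regular optimizer step on the true labels
        logits = model(x)
        loss = loss_fn(logits.view(-1, vocab), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every 10 steps, refresh the Hessian estimate from gradients of the loss
        # on labels sampled from the model's output distribution
        if step % 10 == 9:
            logits = model(x)
            sampled = torch.distributions.Categorical(logits=logits.detach()).sample()
            optimizer.zero_grad()
            loss_fn(logits.view(-1, vocab), sampled.view(-1)).backward()
            optimizer.update_hessian(batch_tokens)
            optimizer.zero_grad()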