
Commit 7fe4a09

simplify configure_optimizers by a lot
1 parent: 196160b · commit: 7fe4a09

1 file changed

Lines changed: 17 additions & 49 deletions

File tree

model.py

@@ -268,60 +268,28 @@ def from_pretrained(cls, model_type, override_args=None):
         return model
 
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
-        """
-        This long function is unfortunately doing something very simple and is being very defensive:
-        We are separating out all parameters of the model into two buckets: those that will experience
-        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
-        We are then returning the PyTorch optimizer object.
-        """
-
-        # separate out all parameters to those that will and won't experience regularizing weight decay
-        decay = set()
-        no_decay = set()
-        whitelist_weight_modules = (torch.nn.Linear, )
-        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
-        for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
-                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
-                # random note: because named_modules and named_parameters are recursive
-                # we will see the same tensors p many many times. but doing it this way
-                # allows us to know which parent module any tensor p belongs to...
-                if pn.endswith('bias'):
-                    # all biases will not be decayed
-                    no_decay.add(fpn)
-                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
-                    # weights of whitelist modules will be weight decayed
-                    decay.add(fpn)
-                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
-                    # weights of blacklist modules will NOT be weight decayed
-                    no_decay.add(fpn)
-
-        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
-        # will appear in the no_decay and decay sets respectively after the above.
-        # In addition, because named_parameters() doesn't return duplicates, it
-        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
-        # so let's manually remove 'lm_head.weight' from decay set. This will include
-        # this tensor into optimization via transformer.wte.weight only, and not decayed.
-        decay.remove('lm_head.weight')
-
-        # validate that we considered every parameter
+        # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
-        inter_params = decay & no_decay
-        union_params = decay | no_decay
-        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
-        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
-                                                    % (str(param_dict.keys() - union_params), )
-
-        # create the pytorch optimizer object
+        # filter out those that do not require grad
+        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
+        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
-            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+            {'params': decay_params, 'weight_decay': weight_decay},
+            {'params': nodecay_params, 'weight_decay': 0.0}
         ]
-        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
-        use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
-        print(f"using fused AdamW: {use_fused}")
+        num_decay_params = sum(p.numel() for p in decay_params)
+        num_nodecay_params = sum(p.numel() for p in nodecay_params)
+        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+        # Create AdamW optimizer and use the fused version if it is available
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == 'cuda'
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        print(f"using fused AdamW: {use_fused}")
 
         return optimizer
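For readers skimming the diff: the new version groups parameters purely by tensor dimensionality (anything 2D or higher gets weight decay, everything else does not) instead of walking modules and matching parameter names. Below is a minimal, self-contained sketch of that rule applied to a toy module; `TinyModel`, its layer sizes, and all hyperparameter values are illustrative assumptions, not part of this commit.

# Illustrative sketch, not part of the commit: only the dim-based grouping rule and
# the fused-AdamW detection mirror the new configure_optimizers; the model and the
# hyperparameter values below are made up for demonstration.
import inspect
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 32)  # 2D weight -> gets weight decay
        self.fc = torch.nn.Linear(32, 32)       # 2D weight decays, 1D bias does not
        self.ln = torch.nn.LayerNorm(32)        # 1D weight and bias -> no weight decay

model = TinyModel()
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
decay_params = [p for p in param_dict.values() if p.dim() >= 2]
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
optim_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': nodecay_params, 'weight_decay': 0.0},
]
# prefer the fused AdamW kernel when the installed PyTorch exposes it and CUDA is present
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and torch.cuda.is_available()
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=6e-4, betas=(0.9, 0.95), **extra_args)

A typical call site for the simplified method would then look like the following; the hyperparameter values are again shown only for illustration.

# Hypothetical call into the simplified method; `model` is assumed to be a GPT instance.
optimizer = model.configure_optimizers(
    weight_decay=0.1,
    learning_rate=6e-4,
    betas=(0.9, 0.95),
    device_type='cuda',  # enables the fused AdamW path when available
)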

0 commit comments