@@ -268,60 +268,28 @@ def from_pretrained(cls, model_type, override_args=None):
         return model
 
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
-        """
-        This long function is unfortunately doing something very simple and is being very defensive:
-        We are separating out all parameters of the model into two buckets: those that will experience
-        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
-        We are then returning the PyTorch optimizer object.
-        """
-
-        # separate out all parameters to those that will and won't experience regularizing weight decay
-        decay = set()
-        no_decay = set()
-        whitelist_weight_modules = (torch.nn.Linear, )
-        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
-        for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
-                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
-                # random note: because named_modules and named_parameters are recursive
-                # we will see the same tensors p many many times. but doing it this way
-                # allows us to know which parent module any tensor p belongs to...
-                if pn.endswith('bias'):
-                    # all biases will not be decayed
-                    no_decay.add(fpn)
-                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
-                    # weights of whitelist modules will be weight decayed
-                    decay.add(fpn)
-                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
-                    # weights of blacklist modules will NOT be weight decayed
-                    no_decay.add(fpn)
-
-        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
-        # will appear in the no_decay and decay sets respectively after the above.
-        # In addition, because named_parameters() doesn't return duplicates, it
-        # will only return the first occurrence, key'd by 'transformer.wte.weight', below.
-        # so let's manually remove 'lm_head.weight' from decay set. This will include
-        # this tensor into optimization via transformer.wte.weight only, and not decayed.
-        decay.remove('lm_head.weight')
-
-        # validate that we considered every parameter
+        # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
-        inter_params = decay & no_decay
-        union_params = decay | no_decay
-        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
-        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
-                                                    % (str(param_dict.keys() - union_params), )
-
-        # create the pytorch optimizer object
+        # filter out those that do not require grad
+        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
+        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
-            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+            {'params': decay_params, 'weight_decay': weight_decay},
+            {'params': nodecay_params, 'weight_decay': 0.0}
         ]
-        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
-        use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
-        print(f"using fused AdamW: {use_fused}")
+        num_decay_params = sum(p.numel() for p in decay_params)
+        num_nodecay_params = sum(p.numel() for p in nodecay_params)
+        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+        # Create AdamW optimizer and use the fused version if it is available
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == 'cuda'
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        print(f"using fused AdamW: {use_fused}")
 
         return optimizer
 
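Note on the new grouping rule: the commit replaces the old module-type whitelist/blacklist bookkeeping with a single check on tensor dimensionality. Below is a minimal standalone sketch, not part of the commit, that applies the same `p.dim() >= 2` rule to a toy module so you can see which tensors land in each group (module and variable names here are illustrative only).

# Standalone sketch: the dim()-based decay split on a toy module.
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 4)  # 2D weight -> decayed
        self.fc = torch.nn.Linear(4, 4)       # 2D weight -> decayed, 1D bias -> not
        self.ln = torch.nn.LayerNorm(4)       # 1D weight and bias -> not decayed

toy = Toy()
params = {n: p for n, p in toy.named_parameters() if p.requires_grad}
decay = [n for n, p in params.items() if p.dim() >= 2]
no_decay = [n for n, p in params.items() if p.dim() < 2]
print(decay)     # ['emb.weight', 'fc.weight']
print(no_decay)  # ['fc.bias', 'ln.weight', 'ln.bias']

Because embedding weights are 2D, they now receive weight decay (they were blacklisted before), while all biases and LayerNorm parameters remain undecayed, matching the old behavior for those tensors.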