File tree Expand file tree Collapse file tree
Expand file tree · Collapse file tree · Original file line number · Diff line number · Diff line change
@@ -240,6 +240,7 @@ def get_lr(it):
240  240      X, Y = get_batch('train')  # fetch the very first batch
241  241      t0 = time.time()
242  242      local_iter_num = 0  # number of iterations in the lifetime of this process
     243  +   raw_model = model.module if ddp else model  # unwrap DDP container if needed
243  244      running_mfu = -1.0
244  245      while True:
245  246  
@@ -262,7 +263,6 @@ def get_lr(it):
262  263              })
263  264          if losses['val'] < best_val_loss or always_save_checkpoint:
264  265              best_val_loss = losses['val']
265       -         raw_model = model.module if ddp else model
266  266              if iter_num > 0:
267  267                  checkpoint = {
268  268                      'model': raw_model.state_dict(),
@@ -309,7 +309,7 @@ def get_lr(it):
309  309          if iter_num % log_interval == 0 and master_process:
310  310              lossf = loss.item()  # loss as float. note: this is a CPU-GPU sync point
311  311              if local_iter_num >= 5:  # let the training loop settle a bit
312       -             mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
     312  +             mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
313  313              running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
314  314              print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
315  315          iter_num += 1
You can’t perform that action at this time.
0 commit comments