Skip to content

Commit ab21d6c

Browse files
committed
Bugfix: we have to call the raw_model's estimate_mfu. Thanks to @jprobichaud for the original PR.
1 parent f83dd03 commit ab21d6c

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def get_lr(it):
240240
X, Y = get_batch('train') # fetch the very first batch
241241
t0 = time.time()
242242
local_iter_num = 0 # number of iterations in the lifetime of this process
243+
raw_model = model.module if ddp else model # unwrap DDP container if needed
243244
running_mfu = -1.0
244245
while True:
245246

@@ -262,7 +263,6 @@ def get_lr(it):
262263
})
263264
if losses['val'] < best_val_loss or always_save_checkpoint:
264265
best_val_loss = losses['val']
265-
raw_model = model.module if ddp else model
266266
if iter_num > 0:
267267
checkpoint = {
268268
'model': raw_model.state_dict(),
@@ -309,7 +309,7 @@ def get_lr(it):
309309
if iter_num % log_interval == 0 and master_process:
310310
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
311311
if local_iter_num >= 5: # let the training loop settle a bit
312-
mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
312+
mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
313313
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
314314
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
315315
iter_num += 1

0 commit comments

Comments
 (0)