Commit e58f0cf

oops i should not be needing or multiplying by world_size to calculate mfu
1 parent 8b1e432 commit e58f0cf

1 file changed

train.py (1 addition & 3 deletions)
@@ -84,14 +84,12 @@
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
-    world_size = int(os.environ['WORLD_SIZE']) # total number of training processes
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
 else:
     # if not ddp, we are running on a single gpu, and one process
-    world_size = 1
     master_process = True
     seed_offset = 0

@@ -309,7 +307,7 @@ def get_lr(it):
     if iter_num % log_interval == 0 and master_process:
         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
         if local_iter_num >= 5: # let the training loop settle a bit
-            mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
+            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
     iter_num += 1
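Why dropping the world_size factor is the right fix: MFU (model FLOPs utilization) compares the FLOPs one process actually achieves per second against the peak of a single GPU (e.g. an A100's 312 TFLOPS bfloat16 peak). Under DDP, batch_size, gradient_accumulation_steps, and dt are all per-process quantities, so the workload handed to estimate_mfu should be per-process as well; multiplying by world_size would overstate the reported utilization roughly by the number of GPUs. With that multiplication gone, the world_size variable read in the DDP branch (and its single-GPU fallback) has no remaining use, which is presumably why the first hunk deletes it. The sketch below illustrates the arithmetic in isolation; it is not the estimate_mfu defined in model.py, and the helper name, the 6*N flops-per-token approximation, and the example numbers are assumptions for illustration only.

# Minimal sketch of per-GPU MFU, assuming the standard ~6*N flops/token
# approximation for a forward+backward pass. NOT the estimate_mfu from
# model.py; estimate_mfu_per_gpu is a hypothetical helper name.
def estimate_mfu_per_gpu(n_params, tokens_per_iter_per_gpu, dt, peak_flops=312e12):
    # tokens_per_iter_per_gpu: tokens this ONE process handles per optimizer step,
    # i.e. batch_size * gradient_accumulation_steps * block_size. No world_size
    # factor: dt is this process's wall-clock time and peak_flops is one GPU's
    # peak, so the ratio is already a per-GPU utilization.
    flops_per_token = 6 * n_params                      # rough fwd+bwd cost per token
    flops_per_iter = flops_per_token * tokens_per_iter_per_gpu
    flops_achieved = flops_per_iter / dt                # flops per second on this GPU
    return flops_achieved / peak_flops                  # fraction of one GPU's peak

# Made-up numbers: a 124M-parameter model, batch_size=12, 40 grad-accum steps,
# block_size=1024, and 1.8 s per iteration give roughly 0.65, i.e. ~65% of peak.
# print(estimate_mfu_per_gpu(124e6, 12 * 40 * 1024, 1.8))

Note that in the diff, estimate_mfu is handed batch_size * gradient_accumulation_steps (sequences per step, per process) rather than tokens, but the point is the same in either parameterization: once dt and the peak are both single-GPU quantities, scaling the workload by world_size overstates utilization by a factor of the number of processes.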
