Skip to content

Commit 978d4fe

Browse files
committed
Fix for gradient_accumulation_steps training slow
1 parent a82b33b commit 978d4fe

3 files changed

Lines changed: 5 additions & 3 deletions

File tree

config/train_gpt2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
1111
batch_size = 12
1212
block_size = 1024
13-
gradient_accumulation_steps = 5
13+
gradient_accumulation_steps = 5 * 8
1414

1515
# this makes total number of tokens be 300B
1616
max_iters = 600000

config/train_shakespeare_char.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
wandb_run_name = 'mini-gpt'
1515

1616
dataset = 'shakespeare_char'
17+
gradient_accumulation_steps = 1
1718
batch_size = 64
1819
block_size = 256 # context of up to 256 previous characters
1920

train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
wandb_run_name = 'gpt2' # 'run' + str(time.time())
4646
# data
4747
dataset = 'openwebtext'
48-
gradient_accumulation_steps = 5 # used to simulate larger batch sizes
48+
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
4949
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
5050
block_size = 1024
5151
# model
@@ -88,11 +88,12 @@
8888
torch.cuda.set_device(device)
8989
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
9090
seed_offset = ddp_rank # each process gets a different seed
91+
assert gradient_accumulation_steps % torch.cuda.device_count() == 0
92+
gradient_accumulation_steps //= torch.cuda.device_count()
9193
else:
9294
# if not ddp, we are running on a single gpu, and one process
9395
master_process = True
9496
seed_offset = 0
95-
gradient_accumulation_steps *= 8 # simulate 8 gpus
9697

9798
if master_process:
9899
os.makedirs(out_dir, exist_ok=True)

0 commit comments

Comments
 (0)