Commit 3611338

Merge pull request karpathy#71 from cchan/patch-1

Zero-grad more aggressively to save memory

2 parents: 1f77d03 + 6716607

1 file changed: train.py (1 addition, 1 deletion)
@@ -259,7 +259,6 @@ def get_lr(iter):
         break

     # forward backward update, with optional gradient accumulation to simulate larger batch size
-    optimizer.zero_grad(set_to_none=True)
     for micro_step in range(gradient_accumulation_steps):
         X, Y = get_batch('train')
         if ddp:
@@ -272,6 +271,7 @@ def get_lr(iter):
             logits, loss = model(X, Y)
         loss.backward()
     optimizer.step()
+    optimizer.zero_grad(set_to_none=True)

     # timing and logging
     t1 = time.time()
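
The change moves the single optimizer.zero_grad(set_to_none=True) call from the top of the iteration to immediately after optimizer.step(). The gradients are fully consumed by the step, so zeroing them right away (with set_to_none=True, which frees the .grad tensors rather than overwriting them with zeros) releases the gradient memory for the remainder of the iteration, instead of letting it linger until the next iteration's forward/backward phase. Below is a minimal runnable sketch of the resulting loop shape, using a toy model and a hypothetical get_batch stand-in for train.py's real data loader:

import torch

model = torch.nn.Linear(16, 16)            # toy stand-in for the GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4            # name taken from the diff above

def get_batch():
    # hypothetical toy batch; train.py's get_batch('train') loads real data
    x = torch.randn(8, 16)
    return x, x

for iter_num in range(3):
    # forward backward update, with optional gradient accumulation
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        logits = model(X)
        loss = torch.nn.functional.mse_loss(logits, Y)
        loss.backward()                    # grads accumulate across micro steps
    optimizer.step()
    # zero-grad immediately after the step: set_to_none=True frees the .grad
    # tensors, so their memory is available for the rest of the iteration
    optimizer.zero_grad(set_to_none=True)

Note this does not change the math: gradients are still zero (here, None) at the start of each backward phase. It only shortens the lifetime of the gradient buffers, which lowers memory pressure during the logging and evaluation work between updates.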
