2 parents 1f77d03 + 6716607 · commit 3611338
1 file changed
train.py
@@ -259,7 +259,6 @@ def get_lr(iter):
         break
 
     # forward backward update, with optional gradient accumulation to simulate larger batch size
-    optimizer.zero_grad(set_to_none=True)
     for micro_step in range(gradient_accumulation_steps):
         X, Y = get_batch('train')
         if ddp:
@@ -272,6 +271,7 @@ def get_lr(iter):
         logits, loss = model(X, Y)
         loss.backward()
     optimizer.step()
+    optimizer.zero_grad(set_to_none=True)
 
     # timing and logging
     t1 = time.time()
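
The net effect of the diff is that gradients are now cleared immediately after the optimizer update, rather than at the top of each iteration before the accumulation loop. Below is a minimal, self-contained sketch of the resulting loop shape; the toy model, data, and loss scaling are stand-ins for illustration only, not the actual train.py code.

# Minimal sketch (toy model/data, not the actual train.py) of the pattern this
# commit adopts: accumulate gradients over micro-steps, update, then flush.
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4  # simulate a 4x larger effective batch

def get_batch():
    # stand-in for train.py's get_batch('train'): random data for illustration
    return torch.randn(8, 16), torch.randn(8, 1)

for step in range(10):
    # forward/backward over several micro-batches; grads accumulate in .grad
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        loss = nn.functional.mse_loss(model(X), Y)
        # scaling by the number of micro-steps is a common choice in this sketch,
        # so the accumulated gradient approximates that of one larger batch
        (loss / gradient_accumulation_steps).backward()
    optimizer.step()
    # flush gradients right after the update; set_to_none=True frees the
    # .grad tensors instead of filling them with zeros
    optimizer.zero_grad(set_to_none=True)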