Skip to content

Commit a6a708c

Browse files
authored
Merge branch 'master' into grad_accum
2 parents 978d4fe + d9f4735 commit a6a708c

5 files changed

Lines changed: 28 additions & 11 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.DS_Store
2+
.ipynb_checkpoints/
3+
__pycache__/
4+
*.pyc

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Dependencies:
1919
- `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
2020
- `pip install tiktoken` for OpenAI's fast BPE code <3
2121
- `pip install wandb` for optional logging <3
22-
- `pip install tqdm`
22+
- `pip install tqdm` <3
2323

2424
## quick start
2525

@@ -37,7 +37,7 @@ This creates a `train.bin` and `val.bin` in that data directory. Now it is time
3737
$ python train.py config/train_shakespeare_char.py
3838
```
3939

40-
If you peak inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
40+
If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
4141

4242
```
4343
$ python sample.py --out_dir=out-shakespeare-char
@@ -84,7 +84,7 @@ bot thou the sought bechive in that to doth groan you,
8484
No relving thee post mose the wear
8585
```
8686

87-
Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
87+
Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
8888

8989
Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.
9090

data/openwebtext/prepare.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,16 @@ def process(example):
5454
filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
5555
dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
5656
arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
57+
total_batches = 1024
5758

58-
print(f"writing {filename}...")
5959
idx = 0
60-
for example in tqdm(dset):
61-
arr[idx : idx + example['len']] = example['ids']
62-
idx += example['len']
60+
for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
61+
# Batch together samples for faster write
62+
batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
63+
arr_batch = np.concatenate(batch['ids'])
64+
# Write into mmap
65+
arr[idx : idx + len(arr_batch)] = arr_batch
66+
idx += len(arr_batch)
6367
arr.flush()
6468

6569
# train.bin is ~17GB, val.bin ~8.5MB

model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ def forward(self, x):
6161
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
6262

6363
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
64-
q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
64+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
6565
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
6666
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
6767
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
6868

6969
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
7070
if self.flash:
7171
# efficient attention using Flash Attention CUDA kernels
72-
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
72+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
7373
else:
7474
# manual implementation of attention
7575
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -207,7 +207,8 @@ def crop_block_size(self, block_size):
207207
self.config.block_size = block_size
208208
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
209209
for block in self.transformer.h:
210-
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
210+
if hasattr(block.attn, 'bias'):
211+
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
211212

212213
@classmethod
213214
def from_pretrained(cls, model_type, override_args=None):

train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
init_process_group(backend=backend)
8585
ddp_rank = int(os.environ['RANK'])
8686
ddp_local_rank = int(os.environ['LOCAL_RANK'])
87+
ddp_world_size = int(os.environ['WORLD_SIZE'])
8788
device = f'cuda:{ddp_local_rank}'
8889
torch.cuda.set_device(device)
8990
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
@@ -94,6 +95,9 @@
9495
# if not ddp, we are running on a single gpu, and one process
9596
master_process = True
9697
seed_offset = 0
98+
ddp_world_size = 1
99+
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
100+
print(f"tokens per iteration will be: {tokens_per_iter:,}")
97101

98102
if master_process:
99103
os.makedirs(out_dir, exist_ok=True)
@@ -190,6 +194,7 @@ def get_batch(split):
190194
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
191195
if init_from == 'resume':
192196
optimizer.load_state_dict(checkpoint['optimizer'])
197+
checkpoint = None # free up memory
193198

194199
# compile the model
195200
if compile:
@@ -288,6 +293,7 @@ def get_lr(it):
288293
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
289294
with ctx:
290295
logits, loss = model(X, Y)
296+
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
291297
# immediately async prefetch next batch while model is doing the forward pass on the GPU
292298
X, Y = get_batch('train')
293299
# backward pass, with gradient scaling if training in fp16
@@ -307,7 +313,9 @@ def get_lr(it):
307313
dt = t1 - t0
308314
t0 = t1
309315
if iter_num % log_interval == 0 and master_process:
310-
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
316+
# get loss as float. note: this is a CPU-GPU sync point
317+
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
318+
lossf = loss.item() * gradient_accumulation_steps
311319
if local_iter_num >= 5: # let the training loop settle a bit
312320
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
313321
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu

0 commit comments

Comments
 (0)