Commit 4eb7a96

Merge pull request karpathy#305 from okuvshynov/fix_osx_dataload
nanogpt: fix multiprocessing in load_dataset on os x
2 parents 41d7014 + 542ac51 commit 4eb7a96
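
Note (not part of the commit itself): the diff below wraps the body of prepare.py in an if __name__ == '__main__': guard. On macOS, Python's multiprocessing defaults to the "spawn" start method (Python 3.8+), so worker processes re-import the script; without the guard, the top-level load_dataset(..., num_proc=...) call would run again in every worker and the script fails during bootstrapping. A minimal standalone sketch of the same guard pattern (the square helper and pool size are illustrative, not from the repo):

from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == '__main__':
    # Only the parent process runs this block; spawned workers re-import the
    # module, see just the function definition above, and do not recurse.
    with Pool(4) as pool:
        print(pool.map(square, range(8)))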

1 file changed

Lines changed: 54 additions & 53 deletions

data/openwebtext/prepare.py

@@ -16,64 +16,65 @@
 # it is better than 1 usually though
 num_proc_load_dataset = num_proc

-# takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
-dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)
+if __name__ == '__main__':
+    # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
+    dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)

-# owt by default only contains the 'train' split, so create a test split
-split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
-split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
+    # owt by default only contains the 'train' split, so create a test split
+    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
+    split_dataset['val'] = split_dataset.pop('test') # rename the test split to val

-# this results in:
-# >>> split_dataset
-# DatasetDict({
-#     train: Dataset({
-#         features: ['text'],
-#         num_rows: 8009762
-#     })
-#     val: Dataset({
-#         features: ['text'],
-#         num_rows: 4007
-#     })
-# })
+    # this results in:
+    # >>> split_dataset
+    # DatasetDict({
+    #     train: Dataset({
+    #         features: ['text'],
+    #         num_rows: 8009762
+    #     })
+    #     val: Dataset({
+    #         features: ['text'],
+    #         num_rows: 4007
+    #     })
+    # })

-# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
-enc = tiktoken.get_encoding("gpt2")
-def process(example):
-    ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
-    ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
-    # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
-    out = {'ids': ids, 'len': len(ids)}
-    return out
+    # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
+    enc = tiktoken.get_encoding("gpt2")
+    def process(example):
+        ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
+        ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
+        # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
+        out = {'ids': ids, 'len': len(ids)}
+        return out

-# tokenize the dataset
-tokenized = split_dataset.map(
-    process,
-    remove_columns=['text'],
-    desc="tokenizing the splits",
-    num_proc=num_proc,
-)
+    # tokenize the dataset
+    tokenized = split_dataset.map(
+        process,
+        remove_columns=['text'],
+        desc="tokenizing the splits",
+        num_proc=num_proc,
+    )

-# concatenate all the ids in each dataset into one large file we can use for training
-for split, dset in tokenized.items():
-    arr_len = np.sum(dset['len'], dtype=np.uint64)
-    filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
-    dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
-    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
-    total_batches = 1024
+    # concatenate all the ids in each dataset into one large file we can use for training
+    for split, dset in tokenized.items():
+        arr_len = np.sum(dset['len'], dtype=np.uint64)
+        filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
+        dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
+        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
+        total_batches = 1024

-    idx = 0
-    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
-        # Batch together samples for faster write
-        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
-        arr_batch = np.concatenate(batch['ids'])
-        # Write into mmap
-        arr[idx : idx + len(arr_batch)] = arr_batch
-        idx += len(arr_batch)
-    arr.flush()
+        idx = 0
+        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
+            # Batch together samples for faster write
+            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
+            arr_batch = np.concatenate(batch['ids'])
+            # Write into mmap
+            arr[idx : idx + len(arr_batch)] = arr_batch
+            idx += len(arr_batch)
+        arr.flush()

-# train.bin is ~17GB, val.bin ~8.5MB
-# train has ~9B tokens (9,035,582,198)
-# val has ~4M tokens (4,434,897)
+    # train.bin is ~17GB, val.bin ~8.5MB
+    # train has ~9B tokens (9,035,582,198)
+    # val has ~4M tokens (4,434,897)

-# to read the bin files later, e.g. with numpy:
-# m = np.memmap('train.bin', dtype=np.uint16, mode='r')
+    # to read the bin files later, e.g. with numpy:
+    # m = np.memmap('train.bin', dtype=np.uint16, mode='r')

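Sanity-check sketch (not part of the commit; assumes prepare.py has been run and train.bin sits in the current directory as uint16 GPT-2 token ids, per the comments above):

import numpy as np
import tiktoken

# memory-map the token file and decode a small prefix back to text
m = np.memmap('train.bin', dtype=np.uint16, mode='r')
enc = tiktoken.get_encoding("gpt2")
print(f"{len(m):,} tokens in train.bin")
print(enc.decode(m[:64].tolist()))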