
Commit f68ac22

Merge pull request karpathy#428 from kjslag/memmap-memory-leak
fix np.memmap memory leak
2 parents: eba36e8 + 5156fef

1 file changed: train.py (6 additions & 3 deletions)
@@ -113,10 +113,13 @@

 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
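
For reference, a minimal self-contained sketch of the patched loader as it reads after this change. The dataset, block_size, and batch_size values below are hypothetical stand-ins for globals defined elsewhere in train.py, a return statement is added (the hunk ends before it) so the function runs on its own, and the data/<dataset>/train.bin and val.bin files are assumed to exist (in the repo they are produced by the dataset prepare scripts):

import os
import numpy as np
import torch

dataset = 'shakespeare'   # hypothetical; set by the config in train.py
block_size = 256          # hypothetical context length
batch_size = 12           # hypothetical batch size

data_dir = os.path.join('data', dataset)

def get_batch(split):
    # Recreate np.memmap on every call rather than holding one long-lived
    # mapping: per the Stack Overflow answer linked in the diff, pages read
    # through a persistent memmap keep accumulating in the process's
    # resident memory, while a short-lived mapping is released after each batch.
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    return x, y

Calling get_batch('train') then yields an (x, y) pair of int64 tensors of shape (batch_size, block_size), with y offset one token ahead of x.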

Comments (0)