
Commit 5156fef

fix np.memmap memory leak
np.memmap doesn't release the memory it accesses, so once the dataset has been fully iterated over, the entire dataset ends up resident in RAM. The simplest workaround on Stack Overflow is to recreate the memmap for each batch; the extra overhead is negligible. https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
1 parent eba36e8 commit 5156fef

1 file changed: train.py (6 additions, 3 deletions)
@@ -113,10 +113,13 @@

 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
