 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
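Why the change helps: as the linked Stack Overflow answer explains, a long-lived `np.memmap` keeps every page it has touched mapped into the process, so resident memory grows steadily while training samples random offsets across the whole file. Creating the memmap inside `get_batch` makes it short-lived; once the batch tensors are built, the memmap is garbage-collected and its pages can be released. Below is a minimal, self-contained sketch of the fixed loader; the dataset directory and the `block_size`/`batch_size` values are illustrative assumptions, not taken from the diff.

```python
import os
import numpy as np
import torch

# Assumed values for illustration; in train.py these come from the config.
data_dir = os.path.join('data', 'shakespeare')  # hypothetical dataset directory
block_size = 256   # context length
batch_size = 64    # sequences per batch

def get_batch(split):
    # Recreate np.memmap on every call: a fresh, short-lived memmap is
    # garbage-collected after the batch is built, so touched pages do not
    # accumulate in resident memory over thousands of iterations.
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')
    # Sample batch_size random window starts, then build input/target pairs:
    # y is x shifted one token to the right (next-token prediction).
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    return x, y
```

The trade-off is deliberate: opening a memmap is cheap because it maps the file lazily rather than reading it, so paying that cost once per batch is negligible next to the forward/backward pass, while the memory footprint stays bounded for arbitrarily long training runs.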