Skip to content

Commit 269b356

Browse files
authored
add Exception('Batchsize must be greater than num_gpus.')
1 parent e16c8c8 commit 269b356

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

utils/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def scatter(self, inputs, kwargs, device_ids):
195195
devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
196196
splits = inputs[0].shape[0] // self.num_gpus
197197

198+
if splits == 0:
199+
raise Exception('Batchsize must be greater than num_gpus.')
200+
198201
return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True),
199202
inputs[1][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True))
200203
for device_idx in range(len(devices))], \

0 commit comments

Comments
 (0)