Skip to content

Commit 88e68af

Browse files
committed
add_device_set
1 parent b0f1d73 commit 88e68af

1 file changed

Lines changed: 58 additions & 29 deletions

File tree

FlagEmbedding/flag_reranker.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -151,32 +151,44 @@ def __init__(
151151
self,
152152
model_name_or_path: str = None,
153153
use_fp16: bool = False,
154-
cache_dir: str = None
154+
cache_dir: str = None,
155+
device: Union[str, int] = None
155156
) -> None:
156157

157158
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
158159
self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
159160

160-
if torch.cuda.is_available():
161-
self.device = torch.device("cuda")
162-
elif torch.backends.mps.is_available():
163-
self.device = torch.device("mps")
164-
elif is_torch_npu_available():
165-
self.device = torch.device("npu")
161+
if device and isinstance(device, str):
162+
self.device = torch.device(device)
163+
if device == 'cpu':
164+
use_fp16 = False
166165
else:
167-
self.device = torch.device("cpu")
168-
use_fp16 = False
166+
if torch.cuda.is_available():
167+
if device is not None:
168+
self.device = torch.device(f"cuda:{device}")
169+
else:
170+
self.device = torch.device("cuda")
171+
elif torch.backends.mps.is_available():
172+
self.device = torch.device("mps")
173+
elif is_torch_npu_available():
174+
self.device = torch.device("npu")
175+
else:
176+
self.device = torch.device("cpu")
177+
use_fp16 = False
169178
if use_fp16:
170179
self.model.half()
171180

172181
self.model = self.model.to(self.device)
173182

174183
self.model.eval()
175184

176-
self.num_gpus = torch.cuda.device_count()
177-
if self.num_gpus > 1:
178-
print(f"----------using {self.num_gpus}*GPUs----------")
179-
self.model = torch.nn.DataParallel(self.model)
185+
if device is None:
186+
self.num_gpus = torch.cuda.device_count()
187+
if self.num_gpus > 1:
188+
print(f"----------using {self.num_gpus}*GPUs----------")
189+
self.model = torch.nn.DataParallel(self.model)
190+
else:
191+
self.num_gpus = 1
180192

181193
@torch.no_grad()
182194
def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 256,
@@ -218,7 +230,7 @@ def __init__(
218230
use_fp16: bool = False,
219231
use_bf16: bool = False,
220232
cache_dir: str = None,
221-
device: int = 0
233+
device: Union[str, int] = None
222234
) -> None:
223235
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
224236
cache_dir=cache_dir,
@@ -231,14 +243,21 @@ def __init__(
231243
self.model_name_or_path = model_name_or_path
232244
self.cache_dir = cache_dir
233245

234-
if torch.cuda.is_available():
235-
torch.cuda.set_device(device)
236-
self.device = torch.device('cuda')
237-
elif torch.backends.mps.is_available():
238-
self.device = torch.device('mps')
246+
if device and isinstance(device, str):
247+
self.device = torch.device(device)
239248
else:
240-
self.device = torch.device('cpu')
241-
use_fp16 = False
249+
device = 0 if device is None else device
250+
if torch.cuda.is_available():
251+
torch.cuda.set_device(device)
252+
self.device = torch.device("cuda")
253+
elif torch.backends.mps.is_available():
254+
self.device = torch.device("mps")
255+
elif is_torch_npu_available():
256+
self.device = torch.device("npu")
257+
else:
258+
self.device = torch.device("cpu")
259+
use_fp16 = False
260+
242261
if use_fp16 and use_bf16 is False:
243262
self.model.half()
244263

@@ -311,7 +330,7 @@ def __init__(
311330
use_fp16: bool = False,
312331
use_bf16: bool = False,
313332
cache_dir: str = None,
314-
device: int = 0
333+
device: Union[str, int] = None
315334
) -> None:
316335
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
317336
cache_dir=cache_dir,
@@ -329,14 +348,24 @@ def __init__(
329348
self.model_name_or_path = model_name_or_path
330349
self.cache_dir = cache_dir
331350

332-
if torch.cuda.is_available():
333-
torch.cuda.set_device(device)
334-
self.device = torch.device('cuda')
335-
elif torch.backends.mps.is_available():
336-
self.device = torch.device('mps')
351+
if device and isinstance(device, str):
352+
if device == 'cpu':
353+
warnings.warn('The LLM-based layer-wise reranker does not support CPU; it has been set to CUDA.')
354+
device = 'cuda'
355+
self.device = torch.device(device)
337356
else:
338-
self.device = torch.device('cpu')
339-
use_fp16 = False
357+
device = 0 if device is None else device
358+
if torch.cuda.is_available():
359+
torch.cuda.set_device(device)
360+
self.device = torch.device("cuda")
361+
elif torch.backends.mps.is_available():
362+
self.device = torch.device("mps")
363+
elif is_torch_npu_available():
364+
self.device = torch.device("npu")
365+
else:
366+
self.device = torch.device("cpu")
367+
use_fp16 = False
368+
340369
if use_fp16 and use_bf16 is False:
341370
self.model.half()
342371

0 commit comments

Comments (0)