@@ -151,32 +151,44 @@ def __init__(
151151 self ,
152152 model_name_or_path : str = None ,
153153 use_fp16 : bool = False ,
154- cache_dir : str = None
154+ cache_dir : str = None ,
155+ device : Union [str , int ] = None
155156 ) -> None :
156157
157158 self .tokenizer = AutoTokenizer .from_pretrained (model_name_or_path , cache_dir = cache_dir )
158159 self .model = AutoModelForSequenceClassification .from_pretrained (model_name_or_path , cache_dir = cache_dir )
159160
160- if torch .cuda .is_available ():
161- self .device = torch .device ("cuda" )
162- elif torch .backends .mps .is_available ():
163- self .device = torch .device ("mps" )
164- elif is_torch_npu_available ():
165- self .device = torch .device ("npu" )
161+ if device and isinstance (device , str ):
162+ self .device = torch .device (device )
163+ if device == 'cpu' :
164+ use_fp16 = False
166165 else :
167- self .device = torch .device ("cpu" )
168- use_fp16 = False
166+ if torch .cuda .is_available ():
167+ if device is not None :
168+ self .device = torch .device (f"cuda:{ device } " )
169+ else :
170+ self .device = torch .device ("cuda" )
171+ elif torch .backends .mps .is_available ():
172+ self .device = torch .device ("mps" )
173+ elif is_torch_npu_available ():
174+ self .device = torch .device ("npu" )
175+ else :
176+ self .device = torch .device ("cpu" )
177+ use_fp16 = False
169178 if use_fp16 :
170179 self .model .half ()
171180
172181 self .model = self .model .to (self .device )
173182
174183 self .model .eval ()
175184
176- self .num_gpus = torch .cuda .device_count ()
177- if self .num_gpus > 1 :
178- print (f"----------using { self .num_gpus } *GPUs----------" )
179- self .model = torch .nn .DataParallel (self .model )
185+ if device is None :
186+ self .num_gpus = torch .cuda .device_count ()
187+ if self .num_gpus > 1 :
188+ print (f"----------using { self .num_gpus } *GPUs----------" )
189+ self .model = torch .nn .DataParallel (self .model )
190+ else :
191+ self .num_gpus = 1
180192
181193 @torch .no_grad ()
182194 def compute_score (self , sentence_pairs : Union [List [Tuple [str , str ]], Tuple [str , str ]], batch_size : int = 256 ,
@@ -218,7 +230,7 @@ def __init__(
218230 use_fp16 : bool = False ,
219231 use_bf16 : bool = False ,
220232 cache_dir : str = None ,
221- device : int = 0
233+ device : Union [ str , int ] = None
222234 ) -> None :
223235 self .tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ,
224236 cache_dir = cache_dir ,
@@ -231,14 +243,21 @@ def __init__(
231243 self .model_name_or_path = model_name_or_path
232244 self .cache_dir = cache_dir
233245
234- if torch .cuda .is_available ():
235- torch .cuda .set_device (device )
236- self .device = torch .device ('cuda' )
237- elif torch .backends .mps .is_available ():
238- self .device = torch .device ('mps' )
246+ if device and isinstance (device , str ):
247+ self .device = torch .device (device )
239248 else :
240- self .device = torch .device ('cpu' )
241- use_fp16 = False
249+ device = 0 if device is None else device
250+ if torch .cuda .is_available ():
251+ torch .cuda .set_device (device )
252+ self .device = torch .device ("cuda" )
253+ elif torch .backends .mps .is_available ():
254+ self .device = torch .device ("mps" )
255+ elif is_torch_npu_available ():
256+ self .device = torch .device ("npu" )
257+ else :
258+ self .device = torch .device ("cpu" )
259+ use_fp16 = False
260+
242261 if use_fp16 and use_bf16 is False :
243262 self .model .half ()
244263
@@ -311,7 +330,7 @@ def __init__(
311330 use_fp16 : bool = False ,
312331 use_bf16 : bool = False ,
313332 cache_dir : str = None ,
314- device : int = 0
333+ device : Union [ str , int ] = None
315334 ) -> None :
316335 self .tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ,
317336 cache_dir = cache_dir ,
@@ -329,14 +348,24 @@ def __init__(
329348 self .model_name_or_path = model_name_or_path
330349 self .cache_dir = cache_dir
331350
332- if torch . cuda . is_available ( ):
333- torch . cuda . set_device ( device )
334- self . device = torch . device ( 'cuda ' )
335- elif torch . backends . mps . is_available ():
336- self .device = torch .device ('mps' )
351+ if device and isinstance ( device , str ):
352+ if device == 'cpu' :
353+ warnings . warn ( 'The LLM-based layer-wise reranker does not support CPU; it has been set to CUDA. ' )
354+ device = 'cuda'
355+ self .device = torch .device (device )
337356 else :
338- self .device = torch .device ('cpu' )
339- use_fp16 = False
357+ device = 0 if device is None else device
358+ if torch .cuda .is_available ():
359+ torch .cuda .set_device (device )
360+ self .device = torch .device ("cuda" )
361+ elif torch .backends .mps .is_available ():
362+ self .device = torch .device ("mps" )
363+ elif is_torch_npu_available ():
364+ self .device = torch .device ("npu" )
365+ else :
366+ self .device = torch .device ("cpu" )
367+ use_fp16 = False
368+
340369 if use_fp16 and use_bf16 is False :
341370 self .model .half ()
342371
0 commit comments