-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathnvidia.py
More file actions
32 lines (25 loc) · 1.11 KB
/
nvidia.py
File metadata and controls
32 lines (25 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
NVIDIA Module
"""
class Nvidia:
    """
    A wrapper for the ChatNVIDIA class that provides default configuration
    and could be extended with additional methods if needed.

    Note: This class uses __new__ instead of __init__ because
    langchain_nvidia_ai_endpoints is an optional dependency. We cannot inherit
    from ChatNVIDIA at class definition time since the module may not be
    installed. The __new__ method allows us to lazily import and return a
    ChatNVIDIA instance only when Nvidia() is instantiated.

    Because __new__ returns an object that is not an instance of ``Nvidia``,
    Python never runs ``Nvidia.__init__`` on it, and
    ``isinstance(Nvidia(...), Nvidia)`` is False.

    Args:
        llm_config (dict): Keyword configuration forwarded to ``ChatNVIDIA``.
            A generic ``api_key`` entry is renamed to ``nvidia_api_key``
            before forwarding.

    Returns:
        ChatNVIDIA: the constructed client instance.

    Raises:
        ImportError: if ``langchain_nvidia_ai_endpoints`` cannot be imported.
            The underlying import error is attached as ``__cause__``.
    """

    def __new__(cls, **llm_config):
        try:
            from langchain_nvidia_ai_endpoints import ChatNVIDIA
        except ImportError as err:
            # Chain explicitly so the real cause (e.g. a failing import inside
            # the package itself) remains visible in the traceback.
            raise ImportError(
                "The langchain_nvidia_ai_endpoints module is not installed.\n"
                "Please install it using `pip install langchain-nvidia-ai-endpoints`."
            ) from err

        # Rename the provider-agnostic key to the name passed to ChatNVIDIA.
        # llm_config is a fresh dict built from **kwargs, so popping from it
        # never mutates the caller's mapping. If both keys are supplied,
        # ``api_key`` overwrites ``nvidia_api_key``.
        if "api_key" in llm_config:
            llm_config["nvidia_api_key"] = llm_config.pop("api_key")

        return ChatNVIDIA(**llm_config)