Skip to content

Commit dde9eee

Browse files
authored
Merge pull request #60 from VinciGit00/huggingface_integration
add huggingface integration (embeddings, models ...)
2 parents 7d521ef + dc149e6 commit dde9eee

4 files changed

Lines changed: 42 additions & 18 deletions

File tree

scrapegraphai/graphs/abstract_graph.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
"""
1+
"""
22
Module having abstract class for creating all the graphs
33
"""
44
from abc import ABC, abstractmethod
55
from typing import Optional
6-
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
6+
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace
77
from ..helpers import models_tokens
88

99
class AbstractGraph(ABC):
@@ -48,7 +48,7 @@ def _create_llm(self, llm_config: dict):
4848
# take the model after the last dash
4949
llm_params["model"] = llm_params["model"].split("/")[-1]
5050
try:
51-
self.model_token = models_tokens["openai"][llm_params["model"]]
51+
self.model_token = models_tokens["azure"][llm_params["model"]]
5252
except KeyError:
5353
raise ValueError("Model not supported")
5454
return AzureOpenAI(llm_params)
@@ -61,14 +61,6 @@ def _create_llm(self, llm_config: dict):
6161
return Gemini(llm_params)
6262

6363
elif "ollama" in llm_params["model"]:
64-
"""
65-
Avaiable models:
66-
- llama2
67-
- mistral
68-
- codellama
69-
- dolphin-mixtral
70-
- mistral-openorca
71-
"""
7264
llm_params["model"] = llm_params["model"].split("/")[-1]
7365

7466
# allow user to set model_tokens in config
@@ -81,9 +73,15 @@ def _create_llm(self, llm_config: dict):
8173
raise ValueError("Model not supported")
8274

8375
return Ollama(llm_params)
84-
76+
elif "hugging_face" in llm_params["model"]:
77+
try:
78+
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
79+
except KeyError:
80+
raise ValueError("Model not supported")
81+
return HuggingFace(llm_params)
8582
else:
86-
raise ValueError("Model not supported")
83+
raise ValueError(
84+
"Model provided by the configuration not supported")
8785

8886
def get_execution_info(self):
8987
"""

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .openai_tts import OpenAITextToSpeech
99
from .gemini import Gemini
1010
from .ollama import Ollama
11+
from .hugging_face import HuggingFace
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Module for implementing the hugginface class
3+
"""
4+
from langchain_community.chat_models.huggingface import ChatHuggingFace
5+
6+
7+
class HuggingFace(ChatHuggingFace):
8+
"""Provides a convenient wrapper for interacting with Hugging Face language models
9+
designed for conversational AI applications.
10+
11+
Args:
12+
llm_config (dict): A configuration dictionary containing:
13+
* api_key (str, optional): Your Hugging Face API key.
14+
* model_name (str): The name of the Hugging Face LLM to load.
15+
* tokenizer_name (str, optional): Name of the corresponding tokenizer.
16+
* device (str, optional): Device for running the model ('cpu' by default).
17+
18+
"""
19+
20+
def __init__(self, llm_config: dict):
21+
"""Initializes the HuggingFace chat model wrapper"""
22+
super().__init__(**llm_config)

scrapegraphai/nodes/rag_node.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from langchain.retrievers import ContextualCompressionRetriever
88
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
99
from langchain_community.document_transformers import EmbeddingsRedundantFilter
10+
from langchain_community.embeddings import HuggingFaceHubEmbeddings
1011
from langchain_community.vectorstores import FAISS
1112
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
12-
from ..models import OpenAI, Ollama, AzureOpenAI
13+
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace
1314
from langchain_community.embeddings import OllamaEmbeddings
1415
from .base_node import BaseNode
1516

@@ -26,11 +27,11 @@ class RAGNode(BaseNode):
2627
node_type (str): The type of the node, set to "node" indicating a standard operational node.
2728
2829
Args:
29-
node_name (str, optional): The unique identifier name for the node.
30+
node_name (str, optional): The unique identifier name for the node.
3031
Defaults to "ParseHTMLNode".
3132
3233
Methods:
33-
execute(state): Parses the HTML document contained within the state using
34+
execute(state): Parses the HTML document contained within the state using
3435
the specified tags, if provided, and updates the state with the parsed content.
3536
"""
3637

@@ -44,7 +45,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name:
4445

4546
def execute(self, state):
4647
"""
47-
Executes the node's logic to implement RAG (Retrieval-Augmented Generation)
48+
Executes the node's logic to implement RAG (Retrieval-Augmented Generation)
4849
The method updates the state with relevant chunks of the document.
4950
5051
Args:
@@ -54,7 +55,7 @@ def execute(self, state):
5455
dict: The updated state containing the 'relevant_chunks' key with the relevant chunks.
5556
5657
Raises:
57-
KeyError: If 'document' is not found in the state, indicating that the necessary
58+
KeyError: If 'document' is not found in the state, indicating that the necessary
5859
information for parsing is missing.
5960
"""
6061

@@ -92,6 +93,8 @@ def execute(self, state):
9293
embeddings = AzureOpenAIEmbeddings()
9394
elif isinstance(embedding_model, Ollama):
9495
embeddings = OllamaEmbeddings(model=embedding_model.model)
96+
elif isinstance(embedding_model, HuggingFace):
97+
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
9598
else:
9699
raise ValueError("Embedding Model missing or not supported")
97100

0 commit comments

Comments
 (0)