Skip to content

Commit 53e317d

Browse files
committed
add Hugging Face integration
1 parent 4e57b56 commit 53e317d

4 files changed

Lines changed: 38 additions & 14 deletions

File tree

scrapegraphai/graphs/abstract_graph.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
"""
1+
"""
22
Module having abstract class for creating all the graphs
33
"""
44
from abc import ABC, abstractmethod
55
from typing import Optional
6-
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
6+
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace
77
from ..helpers import models_tokens
88

99

@@ -46,7 +46,7 @@ def _create_llm(self, llm_config: dict):
4646
# take the model after the last dash
4747
llm_params["model"] = llm_params["model"].split("/")[-1]
4848
try:
49-
self.model_token = models_tokens["openai"][llm_params["model"]]
49+
self.model_token = models_tokens["azure"][llm_params["model"]]
5050
except KeyError:
5151
raise ValueError("Model not supported")
5252
return AzureOpenAI(llm_params)
@@ -59,14 +59,6 @@ def _create_llm(self, llm_config: dict):
5959
return Gemini(llm_params)
6060

6161
elif "ollama" in llm_params["model"]:
62-
"""
63-
Avaiable models:
64-
- llama2
65-
- mistral
66-
- codellama
67-
- dolphin-mixtral
68-
- mistral-openorca
69-
"""
7062
llm_params["model"] = llm_params["model"].split("/")[-1]
7163

7264
# allow user to set model_tokens in config
@@ -79,9 +71,15 @@ def _create_llm(self, llm_config: dict):
7971
raise ValueError("Model not supported")
8072

8173
return Ollama(llm_params)
82-
74+
elif "hugging_face" in llm_params["model"]:
75+
try:
76+
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
77+
except KeyError:
78+
raise ValueError("Model not supported")
79+
return HuggingFace(llm_params)
8380
else:
84-
raise ValueError("Model not supported")
81+
raise ValueError(
82+
"Model provided by the configuration not supported")
8583

8684
@abstractmethod
8785
def _create_graph(self):

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .openai_tts import OpenAITextToSpeech
99
from .gemini import Gemini
1010
from .ollama import Ollama
11+
from .hugging_face import HuggingFace
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Module implementing the HuggingFace chat model wrapper class
3+
"""
4+
from langchain_community.chat_models.huggingface import ChatHuggingFace
5+
6+
7+
class HuggingFace(ChatHuggingFace):
8+
"""Wrapper around ChatHuggingFace that takes its settings as a single configuration dict
9+
instead of individual keyword arguments.
10+
11+
Args:
12+
llm_config (dict): Passed unchanged as keyword arguments to ChatHuggingFace. The keys below are the ones listed by the original author; they are not validated here (TODO confirm against ChatHuggingFace):
13+
* api_key (str, optional): Your Hugging Face API key.
14+
* model_name (str): The name of the Hugging Face LLM to load.
15+
* tokenizer_name (str, optional): Name of the corresponding tokenizer.
16+
* device (str, optional): Device for running the model (stated default is 'cpu'; not set in this class, TODO confirm).
17+
18+
"""
19+
20+
def __init__(self, llm_config: dict):
21+
"""Initialize the wrapper by unpacking llm_config into ChatHuggingFace.__init__."""
22+
super().__init__(**llm_config)  # every key in llm_config becomes a keyword argument of the base class

scrapegraphai/nodes/rag_node.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from langchain.retrievers import ContextualCompressionRetriever
88
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
99
from langchain_community.document_transformers import EmbeddingsRedundantFilter
10+
from langchain_community.embeddings import HuggingFaceHubEmbeddings
1011
from langchain_community.vectorstores import FAISS
1112
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
12-
from ..models import OpenAI, Ollama, AzureOpenAI
13+
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace
1314
from langchain_community.embeddings import OllamaEmbeddings
1415
from .base_node import BaseNode
1516

@@ -92,6 +93,8 @@ def execute(self, state):
9293
embeddings = AzureOpenAIEmbeddings()
9394
elif isinstance(embedding_model, Ollama):
9495
embeddings = OllamaEmbeddings(model=embedding_model.model)
96+
elif isinstance(embedding_model, Ollama):
97+
embeddings = HuggingFaceHubEmbeddings(embedding_model, HuggingFace)
9598
else:
9699
raise ValueError("Embedding Model missing or not supported")
97100

0 commit comments

Comments
 (0)