Skip to content

Commit dfdd41a

Browse files
authored
Merge pull request #1039 from MrAliHasan/feat/openai-batch-api-support
2 parents 9fb5f7c + 9d4eba1 commit dfdd41a

6 files changed

Lines changed: 1192 additions & 0 deletions

File tree

scrapegraphai/graphs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .smart_scraper_lite_graph import SmartScraperLiteGraph
2424
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
2525
from .smart_scraper_multi_graph import SmartScraperMultiGraph
26+
from .smart_scraper_multi_batch_graph import SmartScraperMultiBatchGraph
2627
from .smart_scraper_multi_lite_graph import SmartScraperMultiLiteGraph
2728
from .speech_graph import SpeechGraph
2829
from .xml_scraper_graph import XMLScraperGraph
@@ -45,6 +46,7 @@
4546
"SmartScraperGraph",
4647
"SmartScraperLiteGraph",
4748
"SmartScraperMultiGraph",
49+
"SmartScraperMultiBatchGraph",
4850
"SmartScraperMultiLiteGraph",
4951
"SmartScraperMultiConcatGraph",
5052
# Search-related graphs
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
"""
2+
SmartScraperMultiBatchGraph Module
3+
4+
A scraping pipeline that uses the OpenAI Batch API for LLM calls,
5+
providing 50% cost savings compared to real-time API calls.
6+
"""
7+
8+
import asyncio
9+
from copy import deepcopy
10+
from typing import Dict, List, Optional, Type
11+
12+
from pydantic import BaseModel
13+
14+
from ..nodes import FetchNode, GraphIteratorNode, ParseNode
15+
from ..nodes.batch_generate_answer_node import BatchGenerateAnswerNode
16+
from ..nodes.merge_answers_node import MergeAnswersNode
17+
from ..utils.copy import safe_deepcopy
18+
from .abstract_graph import AbstractGraph
19+
from .base_graph import BaseGraph
20+
from .smart_scraper_graph import SmartScraperGraph
21+
22+
23+
class _FetchParseOnlyGraph(AbstractGraph):
    """Internal pipeline that fetches and parses a single source without
    running any LLM generation step.

    Keeping fetch/parse separate from answer generation allows the caller
    to gather every prompt first and submit them together through the
    OpenAI Batch API.
    """

    def __init__(
        self,
        prompt: str,
        source: str,
        config: dict,
        schema: Optional[Type[BaseModel]] = None,
    ):
        super().__init__(prompt, config, source, schema)
        # Web sources are keyed as "url"; anything else is treated as a
        # local path.
        self.input_key = "url" if source.startswith("http") else "local_dir"

    def _create_graph(self) -> BaseGraph:
        """Assemble the two-node fetch -> parse pipeline.

        Returns:
            BaseGraph: graph whose entry point fetches the document and
            whose second node chunks it into ``parsed_doc``.
        """
        fetch_cfg = {
            "llm_model": self.llm_model,
            "force": self.config.get("force", False),
            "cut": self.config.get("cut", True),
            "loader_kwargs": self.config.get("loader_kwargs", {}),
            "browser_base": self.config.get("browser_base"),
            "scrape_do": self.config.get("scrape_do"),
            "storage_state": self.config.get("storage_state"),
        }
        fetcher = FetchNode(
            input="url | local_dir",
            output=["doc"],
            node_config=fetch_cfg,
        )
        parser = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                "llm_model": self.llm_model,
                "chunk_size": self.model_token,
            },
        )
        return BaseGraph(
            nodes=[fetcher, parser],
            edges=[(fetcher, parser)],
            entry_point=fetcher,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """Execute the fetch+parse pipeline and return the parsed text.

        Returns:
            str: the ``parsed_doc`` entry of the final state, or "" when
            parsing produced nothing.
        """
        initial_state = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(initial_state)
        return self.final_state.get("parsed_doc", "")
74+
75+
76+
class SmartScraperMultiBatchGraph(AbstractGraph):
    """A scraping pipeline that uses OpenAI Batch API for cost savings.

    Similar to SmartScraperMultiGraph, but instead of making individual
    LLM calls per URL, it:
    1. Fetches and parses all URLs concurrently (Phase 1)
    2. Collects all prompts and submits them as a single OpenAI Batch (Phase 2)
    3. Polls for batch completion (Phase 3)
    4. Merges all results into a final answer (Phase 4)

    This provides ~50% cost savings on OpenAI API calls at the expense
    of higher latency (up to 24 hours for batch completion).

    Attributes:
        prompt (str): The user prompt for scraping.
        source (List[str]): List of URLs to scrape.
        config (dict): Configuration including 'llm' and optional 'batch_api' settings.
        schema (Optional[BaseModel]): Optional Pydantic schema for structured output.

    Config options under 'batch_api':
        poll_interval (int): Seconds between batch status checks (default: 30).
        max_wait_time (int): Maximum wait time in seconds (default: 86400 = 24h).
        model (str): Override model for batch requests (optional).
        temperature (float): Temperature for batch requests (default: 0.0).

    Raises:
        ValueError: if the configured model names a non-OpenAI provider.

    Example:
        >>> graph = SmartScraperMultiBatchGraph(
        ...     prompt="Extract the main topic and key points",
        ...     source=[
        ...         "https://example.com/page1",
        ...         "https://example.com/page2",
        ...     ],
        ...     config={
        ...         "llm": {"model": "openai/gpt-4o-mini"},
        ...         "batch_api": {
        ...             "poll_interval": 30,
        ...             "max_wait_time": 3600,
        ...         },
        ...     }
        ... )
        >>> result = graph.run()
    """

    def __init__(
        self,
        prompt: str,
        source: List[str],
        config: dict,
        schema: Optional[Type[BaseModel]] = None,
    ):
        # Deep-copy config/schema before super().__init__ may mutate them;
        # the copies are handed to the per-URL sub-graphs.
        self.copy_config = safe_deepcopy(config)
        self.copy_schema = deepcopy(schema)
        self.batch_config = config.get("batch_api", {})

        # Validate that the model is OpenAI-based. Models are written as
        # "provider/name"; a bare model name (no "/") is accepted as-is
        # and left for downstream resolution.
        model_str = config.get("llm", {}).get("model", "")
        provider = model_str.split("/", 1)[0] if "/" in model_str else ""
        if provider and provider != "openai":
            raise ValueError(
                f"SmartScraperMultiBatchGraph only supports OpenAI models. "
                f"Got provider '{provider}'. "
                f"Use SmartScraperMultiGraph for other providers."
            )

        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """Creates the graph of nodes for the batch scraping pipeline.

        The graph has three stages:
        1. GraphIteratorNode runs _FetchParseOnlyGraph per URL (concurrent)
        2. BatchGenerateAnswerNode submits all prompts via Batch API
           (submission and polling both happen inside this node)
        3. MergeAnswersNode combines the results

        Returns:
            BaseGraph: A graph instance representing the batch scraping workflow.
        """
        # Stage 1: Fetch and parse all URLs concurrently
        graph_iterator_node = GraphIteratorNode(
            input="user_prompt & urls",
            output=["parsed_docs"],
            node_config={
                "graph_instance": _FetchParseOnlyGraph,
                "scraper_config": self.copy_config,
            },
            schema=self.copy_schema,
        )

        # Stage 2: Submit all prompts to OpenAI Batch API and poll until done
        batch_generate_node = BatchGenerateAnswerNode(
            input="user_prompt & parsed_docs",
            output=["results"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.copy_schema,
                "batch_config": self.batch_config,
            },
        )

        # Stage 3: Merge all per-URL results into a single answer
        merge_answers_node = MergeAnswersNode(
            input="user_prompt & results",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.copy_schema,
            },
        )

        return BaseGraph(
            nodes=[
                graph_iterator_node,
                batch_generate_node,
                merge_answers_node,
            ],
            edges=[
                (graph_iterator_node, batch_generate_node),
                (batch_generate_node, merge_answers_node),
            ],
            entry_point=graph_iterator_node,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """Executes the full batch scraping pipeline.

        This will:
        1. Fetch and parse all URLs concurrently
        2. Submit all LLM prompts as an OpenAI Batch
        3. Poll until the batch completes (may take minutes to hours)
        4. Merge results into a final answer

        Returns:
            str: The merged answer from all scraped URLs.
        """
        inputs = {"user_prompt": self.prompt, "urls": self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)
        return self.final_state.get("answer", "No answer found.")

scrapegraphai/nodes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from .base_node import BaseNode
6+
from .batch_generate_answer_node import BatchGenerateAnswerNode
67
from .concat_answers_node import ConcatAnswersNode
78
from .conditional_node import ConditionalNode
89
from .description_node import DescriptionNode
@@ -53,6 +54,7 @@
5354
"DescriptionNode",
5455
"ReasoningNode",
5556
# Generation nodes
57+
"BatchGenerateAnswerNode",
5658
"GenerateAnswerNode",
5759
"GenerateAnswerNodeKLevel",
5860
"GenerateAnswerCSVNode",

0 commit comments

Comments
 (0)