Merge pull request #11325 from luke3butler/dev

feat: Add `tavily` as a `RAG_WEB_LOADER_ENGINE` option via extract API
Timothy Jaeryang Baek 2025-03-07 20:11:39 -04:00 committed by GitHub
commit 3ab917cdd0
2 changed files with 254 additions and 67 deletions
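
With this change, setting `RAG_WEB_LOADER_ENGINE` to `tavily` routes web page retrieval through Tavily's Extract API, authenticated with `TAVILY_API_KEY`. The short sketch below mirrors the raw request the new TavilyLoader (first file below) builds, so it can be used to sanity-check a key outside Open WebUI; the API key and URL are placeholders, not values from this PR.

import requests

# Minimal call against the Tavily Extract API, mirroring the loader's payload:
# Bearer auth, a "urls" field, and an "extract_depth" of "basic" or "advanced".
response = requests.post(
    "https://api.tavily.com/extract",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer tvly-...",  # placeholder API key
    },
    json={"urls": ["https://example.com/article"], "extract_depth": "basic"},
)
response.raise_for_status()
data = response.json()
for result in data.get("results", []):
    print(result.get("url", ""), len(result.get("raw_content", "")))
for failed in data.get("failed_results", []):
    print("failed:", failed.get("url"), failed.get("error"))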

open_webui/retrieval/loaders/tavily.py (new file)

@@ -0,0 +1,98 @@
import requests
import logging
from typing import Iterator, List, Literal, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class TavilyLoader(BaseLoader):
    """Extract web page content from URLs using the Tavily Extract API.

    This is a LangChain document loader that uses Tavily's Extract API to
    retrieve content from web pages and return it as Document objects.

    Args:
        urls: URL or list of URLs to extract content from.
        api_key: The Tavily API key.
        extract_depth: Depth of extraction, either "basic" or "advanced".
        continue_on_failure: Whether to continue if extraction of a URL fails.
    """

    def __init__(
        self,
        urls: Union[str, List[str]],
        api_key: str,
        extract_depth: Literal["basic", "advanced"] = "basic",
        continue_on_failure: bool = True,
    ) -> None:
        """Initialize the Tavily Extract client.

        Args:
            urls: URL or list of URLs to extract content from.
            api_key: The Tavily API key.
            extract_depth: Depth of extraction, either "basic" or "advanced".
                Advanced extraction retrieves more data, including tables and
                embedded content, with higher success but may increase latency.
                Basic costs 1 credit per 5 successful URL extractions;
                advanced costs 2 credits per 5 successful URL extractions.
            continue_on_failure: Whether to continue if extraction of a URL fails.
        """
        if not urls:
            raise ValueError("At least one URL must be provided.")

        self.api_key = api_key
        self.urls = urls if isinstance(urls, list) else [urls]
        self.extract_depth = extract_depth
        self.continue_on_failure = continue_on_failure
        self.api_url = "https://api.tavily.com/extract"

    def lazy_load(self) -> Iterator[Document]:
        """Extract and yield documents from the URLs using the Tavily Extract API."""
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            batch_urls = self.urls[i : i + batch_size]
            try:
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                }
                # Use a string for a single URL, an array for multiple URLs
                urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
                payload = {
                    "urls": urls_param,
                    "extract_depth": self.extract_depth,
                }
                # Make the API call
                response = requests.post(
                    self.api_url,
                    headers=headers,
                    json=payload,
                )
                response.raise_for_status()
                response_data = response.json()
                # Process successful results
                for result in response_data.get("results", []):
                    url = result.get("url", "")
                    content = result.get("raw_content", "")
                    if not content:
                        log.warning(f"No content extracted from {url}")
                        continue
                    # Add the source URL as metadata
                    metadata = {"source": url}
                    yield Document(
                        page_content=content,
                        metadata=metadata,
                    )
                # Log URLs the API reports as failed
                for failed in response_data.get("failed_results", []):
                    url = failed.get("url", "")
                    error = failed.get("error", "Unknown error")
                    log.error(f"Failed to extract content from {url}: {error}")
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {batch_urls}: {e}")
                else:
                    raise e
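
For reference, a minimal usage sketch of the loader defined above; the URL and API key are placeholders, and the other arguments simply restate the defaults.

from open_webui.retrieval.loaders.tavily import TavilyLoader

loader = TavilyLoader(
    urls=["https://example.com/article"],  # placeholder URL
    api_key="tvly-...",                    # placeholder key
    extract_depth="basic",
    continue_on_failure=True,
)
# lazy_load batches the URLs (20 per request) and yields one Document per result
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))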

open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.firecrawl import FireCrawlLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
+from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -31,6 +32,7 @@ from open_webui.config import (
     RAG_WEB_LOADER_ENGINE,
     FIRECRAWL_API_BASE_URL,
     FIRECRAWL_API_KEY,
+    TAVILY_API_KEY,
 )
 from open_webui.env import SRC_LOG_LEVELS
@@ -113,7 +115,47 @@ def verify_ssl_cert(url: str) -> bool:
         return False


-class SafeFireCrawlLoader(BaseLoader):
+class RateLimitMixin:
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+
+class URLProcessingMixin:
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for a URL."""
+        return verify_ssl_cert(url)
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
     def __init__(
         self,
         web_paths,
@@ -208,43 +250,120 @@ class SafeFireCrawlLoader(BaseLoader):
                     continue
                 raise e

-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
-
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True
-
-
-class SafePlaywrightURLLoader(PlaywrightURLLoader):
+
+class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
+    def __init__(
+        self,
+        web_paths: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+        requests_per_second: Optional[float] = None,
+        verify_ssl: bool = True,
+        trust_env: bool = False,
+        proxy: Optional[Dict[str, str]] = None,
+    ):
+        """Initialize SafeTavilyLoader with rate limiting and SSL verification support.
+
+        Args:
+            web_paths: List of URLs/paths to process.
+            api_key: The Tavily API key.
+            extract_depth: Depth of extraction ("basic" or "advanced").
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+            requests_per_second: Number of requests per second to limit to.
+            verify_ssl: If True, verify SSL certificates.
+            trust_env: If True, use proxy settings from environment variables.
+            proxy: Optional proxy configuration.
+        """
+        # Initialize proxy configuration if using environment variables
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+
+        # Store parameters for creating TavilyLoader instances
+        self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
+        self.api_key = api_key
+        self.extract_depth = extract_depth
+        self.continue_on_failure = continue_on_failure
+        self.verify_ssl = verify_ssl
+        self.trust_env = trust_env
+        self.proxy = proxy
+
+        # Add rate limiting
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Load documents with rate limiting support, delegating to TavilyLoader."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                self._safe_process_url_sync(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            yield from loader.lazy_load()
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception("Error extracting content from URLs")
+            else:
+                raise e
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async version with rate limiting and SSL verification."""
+        valid_urls = []
+        for url in self.web_paths:
+            try:
+                await self._safe_process_url(url)
+                valid_urls.append(url)
+            except Exception as e:
+                log.warning(f"SSL verification failed for {url}: {str(e)}")
+                if not self.continue_on_failure:
+                    raise e
+
+        if not valid_urls:
+            if self.continue_on_failure:
+                log.warning("No valid URLs to process after SSL verification")
+                return
+            raise ValueError("No valid URLs to process after SSL verification")
+
+        try:
+            loader = TavilyLoader(
+                urls=valid_urls,
+                api_key=self.api_key,
+                extract_depth=self.extract_depth,
+                continue_on_failure=self.continue_on_failure,
+            )
+            async for document in loader.alazy_load():
+                yield document
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception("Error loading URLs")
+            else:
+                raise e
+
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
     """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.

     Attributes:
@@ -356,40 +475,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
                         raise e
             await browser.close()

-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
-
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True


 class SafeWebBaseLoader(WebBaseLoader):
@@ -499,6 +584,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
+RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader


 def get_web_loader(
@@ -525,6 +611,9 @@ def get_web_loader(
         web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
         web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value

+    if RAG_WEB_LOADER_ENGINE.value == "tavily":
+        web_loader_args["api_key"] = TAVILY_API_KEY.value
+
     # Create the appropriate WebLoader based on the configuration
     WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
     web_loader = WebLoaderClass(**web_loader_args)