diff --git a/backend/open_webui/retrieval/loaders/tavily.py b/backend/open_webui/retrieval/loaders/tavily.py
new file mode 100644
index 000000000..b96396eba
--- /dev/null
+++ b/backend/open_webui/retrieval/loaders/tavily.py
@@ -0,0 +1,98 @@
+import requests
+import logging
+from typing import Iterator, List, Literal, Union
+
+from langchain_core.document_loaders import BaseLoader
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class TavilyLoader(BaseLoader):
+    """Extract web page content from URLs using the Tavily Extract API.
+
+    This is a LangChain document loader that uses Tavily's Extract API to
+    retrieve content from web pages and return it as Document objects.
+
+    Args:
+        urls: URL or list of URLs to extract content from.
+        api_key: The Tavily API key.
+        extract_depth: Depth of extraction, either "basic" or "advanced".
+        continue_on_failure: Whether to continue if extraction of a URL fails.
+    """
+
+    def __init__(
+        self,
+        urls: Union[str, List[str]],
+        api_key: str,
+        extract_depth: Literal["basic", "advanced"] = "basic",
+        continue_on_failure: bool = True,
+    ) -> None:
+        """Initialize the Tavily Extract client.
+
+        Args:
+            urls: URL or list of URLs to extract content from.
+            api_key: The Tavily API key.
+            extract_depth: Depth of extraction, either "basic" or "advanced".
+                Advanced extraction retrieves more data, including tables and
+                embedded content, with higher success but may increase latency.
+                Basic costs 1 credit per 5 successful URL extractions;
+                advanced costs 2 credits per 5 successful URL extractions.
+            continue_on_failure: Whether to continue if extraction of a URL fails.
+ """ + if not urls: + raise ValueError("At least one URL must be provided.") + + self.api_key = api_key + self.urls = urls if isinstance(urls, list) else [urls] + self.extract_depth = extract_depth + self.continue_on_failure = continue_on_failure + self.api_url = "https://api.tavily.com/extract" + + def lazy_load(self) -> Iterator[Document]: + """Extract and yield documents from the URLs using Tavily Extract API.""" + batch_size = 20 + for i in range(0, len(self.urls), batch_size): + batch_urls = self.urls[i:i + batch_size] + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + # Use string for single URL, array for multiple URLs + urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls + payload = { + "urls": urls_param, + "extract_depth": self.extract_depth + } + # Make the API call + response = requests.post( + self.api_url, + headers=headers, + json=payload + ) + response.raise_for_status() + response_data = response.json() + # Process successful results + for result in response_data.get("results", []): + url = result.get("url", "") + content = result.get("raw_content", "") + if not content: + log.warning(f"No content extracted from {url}") + continue + # Add URLs as metadata + metadata = {"source": url} + yield Document( + page_content=content, + metadata=metadata, + ) + for failed in response_data.get("failed_results", []): + url = failed.get("url", "") + error = failed.get("error", "Unknown error") + log.error(f"Failed to extract content from {url}: {error}") + except Exception as e: + if self.continue_on_failure: + log.error(f"Error extracting content from batch {batch_urls}: {e}") + else: + raise e \ No newline at end of file diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index fd94a1a32..cffa1b5aa 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa from langchain_community.document_loaders.firecrawl import FireCrawlLoader from langchain_community.document_loaders.base import BaseLoader from langchain_core.documents import Document +from open_webui.retrieval.loaders.tavily import TavilyLoader from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( ENABLE_RAG_LOCAL_WEB_FETCH, @@ -31,6 +32,7 @@ from open_webui.config import ( RAG_WEB_LOADER_ENGINE, FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, + TAVILY_API_KEY, ) from open_webui.env import SRC_LOG_LEVELS @@ -113,7 +115,47 @@ def verify_ssl_cert(url: str) -> bool: return False -class SafeFireCrawlLoader(BaseLoader): +class RateLimitMixin: + async def _wait_for_rate_limit(self): + """Wait to respect the rate limit if specified.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + await asyncio.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + def _sync_wait_for_rate_limit(self): + """Synchronous version of rate limit wait.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + time.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time 
diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index fd94a1a32..cffa1b5aa 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -24,6 +24,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoa
 from langchain_community.document_loaders.firecrawl import FireCrawlLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
+from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.config import (
     ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -31,6 +32,7 @@ from open_webui.config import (
     RAG_WEB_LOADER_ENGINE,
     FIRECRAWL_API_BASE_URL,
     FIRECRAWL_API_KEY,
+    TAVILY_API_KEY,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -113,7 +115,47 @@ def verify_ssl_cert(url: str) -> bool:
     return False
 
 
-class SafeFireCrawlLoader(BaseLoader):
+class RateLimitMixin:
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+
+class URLProcessingMixin:
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for a URL."""
+        return verify_ssl_cert(url)
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
     def __init__(
         self,
         web_paths,
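
Reviewer note on the mixin contract: both mixins read attributes (`requests_per_second`, `last_request_time`, `verify_ssl`) that the host class is expected to define. A minimal hypothetical composition, just to illustrate the expected shape; `MyLoader` does not exist in this patch:

```python
# Hypothetical host class showing the attributes the mixins expect.
from open_webui.retrieval.web.utils import RateLimitMixin, URLProcessingMixin


class MyLoader(RateLimitMixin, URLProcessingMixin):
    def __init__(self, requests_per_second=2.0, verify_ssl=True):
        self.requests_per_second = requests_per_second  # read by RateLimitMixin
        self.last_request_time = None                   # updated after each wait
        self.verify_ssl = verify_ssl                    # read by URLProcessingMixin

    def load_all(self, urls):
        for url in urls:
            # Raises ValueError on SSL failure; sleeps to honor the rate limit
            self._safe_process_url_sync(url)
            ...  # fetch and process url here
```
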
+ """ + # Initialize proxy configuration if using environment variables + proxy_server = proxy.get("server") if proxy else None + if trust_env and not proxy_server: + env_proxies = urllib.request.getproxies() + env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + if env_proxy_server: + if proxy: + proxy["server"] = env_proxy_server + else: + proxy = {"server": env_proxy_server} + + # Store parameters for creating TavilyLoader instances + self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths] + self.api_key = api_key + self.extract_depth = extract_depth + self.continue_on_failure = continue_on_failure + self.verify_ssl = verify_ssl + self.trust_env = trust_env + self.proxy = proxy + + # Add rate limiting + self.requests_per_second = requests_per_second + self.last_request_time = None - async def _safe_process_url(self, url: str) -> bool: - """Perform safety checks before processing a URL.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - await self._wait_for_rate_limit() - return True - - def _safe_process_url_sync(self, url: str) -> bool: - """Synchronous version of safety checks.""" - if self.verify_ssl and not self._verify_ssl_cert(url): - raise ValueError(f"SSL certificate verification failed for {url}") - self._sync_wait_for_rate_limit() - return True + def lazy_load(self) -> Iterator[Document]: + """Load documents with rate limiting support, delegating to TavilyLoader.""" + valid_urls = [] + for url in self.web_paths: + try: + self._safe_process_url_sync(url) + valid_urls.append(url) + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + if not self.continue_on_failure: + raise e + if not valid_urls: + if self.continue_on_failure: + log.warning("No valid URLs to process after SSL verification") + return + raise ValueError("No valid URLs to process after SSL verification") + try: + loader = TavilyLoader( + urls=valid_urls, + api_key=self.api_key, + extract_depth=self.extract_depth, + continue_on_failure=self.continue_on_failure, + ) + yield from loader.lazy_load() + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error extracting content from URLs") + else: + raise e + + async def alazy_load(self) -> AsyncIterator[Document]: + """Async version with rate limiting and SSL verification.""" + valid_urls = [] + for url in self.web_paths: + try: + await self._safe_process_url(url) + valid_urls.append(url) + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + if not self.continue_on_failure: + raise e + + if not valid_urls: + if self.continue_on_failure: + log.warning("No valid URLs to process after SSL verification") + return + raise ValueError("No valid URLs to process after SSL verification") + + try: + loader = TavilyLoader( + urls=valid_urls, + api_key=self.api_key, + extract_depth=self.extract_depth, + continue_on_failure=self.continue_on_failure, + ) + async for document in loader.alazy_load(): + yield document + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error loading URLs") + else: + raise e -class SafePlaywrightURLLoader(PlaywrightURLLoader): +class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin): """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection. 
 
 
-class SafePlaywrightURLLoader(PlaywrightURLLoader):
+class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
     """Load HTML pages safely with Playwright, supporting SSL verification,
     rate limiting, and remote browser connection.
 
     Attributes:
@@ -356,40 +475,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
                 raise e
             await browser.close()
 
-    def _verify_ssl_cert(self, url: str) -> bool:
-        return verify_ssl_cert(url)
-
-    async def _wait_for_rate_limit(self):
-        """Wait to respect the rate limit if specified."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                await asyncio.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    def _sync_wait_for_rate_limit(self):
-        """Synchronous version of rate limit wait."""
-        if self.requests_per_second and self.last_request_time:
-            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
-            time_since_last = datetime.now() - self.last_request_time
-            if time_since_last < min_interval:
-                time.sleep((min_interval - time_since_last).total_seconds())
-        self.last_request_time = datetime.now()
-
-    async def _safe_process_url(self, url: str) -> bool:
-        """Perform safety checks before processing a URL."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        await self._wait_for_rate_limit()
-        return True
-
-    def _safe_process_url_sync(self, url: str) -> bool:
-        """Synchronous version of safety checks."""
-        if self.verify_ssl and not self._verify_ssl_cert(url):
-            raise ValueError(f"SSL certificate verification failed for {url}")
-        self._sync_wait_for_rate_limit()
-        return True
 
 
 class SafeWebBaseLoader(WebBaseLoader):
@@ -499,6 +584,7 @@ RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
 RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
 RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
 RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
+RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
 
 
 def get_web_loader(
@@ -525,6 +611,9 @@
         web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
         web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
 
+    if RAG_WEB_LOADER_ENGINE.value == "tavily":
+        web_loader_args["api_key"] = TAVILY_API_KEY.value
+
     # Create the appropriate WebLoader based on the configuration
     WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
     web_loader = WebLoaderClass(**web_loader_args)
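
Reviewer note: with the engine registered, selection is purely configuration-driven. A sketch of the resulting call path, assuming `get_web_loader`'s existing parameters as shown in the surrounding context (values are illustrative):

```python
# Illustrative sketch of engine selection after this patch.
# Assumes RAG_WEB_LOADER_ENGINE is configured as "tavily" and
# TAVILY_API_KEY holds a valid key; both are existing config values.
from open_webui.retrieval.web.utils import get_web_loader

loader = get_web_loader("https://example.com/article")
# get_web_loader injects api_key=TAVILY_API_KEY.value into web_loader_args,
# and RAG_WEB_LOADER_ENGINES["tavily"] resolves to SafeTavilyLoader.
docs = loader.load()
```
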