diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index adfdcfec8..d7e3e5973 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1853,6 +1853,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
+    "RAG_WEB_SEARCH_TRUST_ENV",
+    "rag.web.search.trust_env",
+    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
+)
 
 ####################################
 # Images
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 71142978c..e2e61ec3a 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -175,6 +175,7 @@ from open_webui.config import (
     RAG_WEB_SEARCH_ENGINE,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_TRUST_ENV,
    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     JINA_API_KEY,
     SEARCHAPI_API_KEY,
@@ -558,6 +559,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None
diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index 7fdf2f3c6..caba85ddb 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -5,6 +5,7 @@ import urllib.parse
 import validators
 from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
+
 from langchain_community.document_loaders import (
     WebBaseLoader,
 )
@@ -70,6 +71,45 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader.
+        Args:
+            trust_env (bool, optional): Set to True to make web requests through a
+                proxy, e.g. via the http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
     def _unpack_fetch_results(
         self, results: Any, urls: List[str], parser: Union[str, None] = None
     ) -> List[Any]:
@@ -95,6 +135,7 @@ class SafeWebBaseLoader(WebBaseLoader):
         results = await self.fetch_all(urls)
         return self._unpack_fetch_results(results, urls, parser=parser)
 
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -143,13 +184,15 @@ def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
     return SafeWebBaseLoader(
-        safe_urls,
+        web_path=safe_urls,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
+        trust_env=trust_env,
     )
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index c8ca1d85b..47e6253c8 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -451,6 +451,7 @@ class WebSearchConfig(BaseModel):
     exa_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
+    trust_env: Optional[bool] = None
     domain_filter_list: Optional[List[str]] = []
 
 
@@ -570,6 +571,9 @@ async def update_rag_config(
         request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
         )
+        request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
+            form_data.web.search.trust_env
+        )
         request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
             form_data.web.search.domain_filter_list
         )
@@ -622,6 +626,7 @@ async def update_rag_config(
             "exa_api_key": request.app.state.config.EXA_API_KEY,
             "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
             "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
             "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
         },
     },
@@ -1341,6 +1346,7 @@ async def process_web_search(
         urls,
         verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
         requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+        trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
     )
     docs = await loader.aload()
     await run_in_threadpool(
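
A quick way to exercise the new flag end to end (a sketch for reviewers, not part of this diff; the proxy address and target URL are placeholders):

    import asyncio
    import os

    from open_webui.retrieval.web.utils import get_web_loader

    # Placeholder proxy: with trust_env=True, the aiohttp.ClientSession created
    # in SafeWebBaseLoader._fetch reads this from the environment.
    os.environ["https_proxy"] = "http://127.0.0.1:3128"


    async def main():
        loader = get_web_loader(
            "https://example.com",  # placeholder URL
            verify_ssl=True,
            requests_per_second=2,
            trust_env=True,  # mirrors RAG_WEB_SEARCH_TRUST_ENV=true
        )
        docs = await loader.aload()
        print(f"fetched {len(docs)} document(s)")


    asyncio.run(main())

Note that aiohttp only consults the http(s)_proxy / no_proxy environment variables when trust_env is true, so the default of False leaves today's direct-connection behavior unchanged.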