web loader support proxy

Yimi81 2025-02-14 07:15:09 +00:00
parent 304aed0f13
commit d3f71930f0
4 changed files with 58 additions and 2 deletions

View File

@@ -1853,6 +1853,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
    "RAG_WEB_SEARCH_TRUST_ENV",
    "rag.web.search.trust_env",
    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
)

####################################
# Images
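
A note on the default above: os.getenv returns the raw string whenever the variable is set, so RAG_WEB_SEARCH_TRUST_ENV=false in the environment would still be truthy by the time it reaches aiohttp. A minimal sketch of a stricter boolean parse (the env_bool helper is hypothetical, not part of this commit):

import os

def env_bool(name: str, default: bool = False) -> bool:
    # Treat common truthy spellings as True; unset falls back to the default.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

# "true"/"1"/"yes"/"on" -> True; "false", "0", or unset -> False
trust_env = env_bool("RAG_WEB_SEARCH_TRUST_ENV", False)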

View File

@@ -175,6 +175,7 @@ from open_webui.config import (
    RAG_WEB_SEARCH_ENGINE,
    RAG_WEB_SEARCH_RESULT_COUNT,
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    RAG_WEB_SEARCH_TRUST_ENV,
    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
    JINA_API_KEY,
    SEARCHAPI_API_KEY,
@@ -558,6 +559,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV

app.state.EMBEDDING_FUNCTION = None
app.state.ef = None

View File

@@ -1,7 +1,9 @@
import socket
import aiohttp
import asyncio
import urllib.parse
import validators
from typing import Union, Sequence, Iterator, Dict
from langchain_community.document_loaders import (
    WebBaseLoader,
@@ -68,6 +70,45 @@ def resolve_hostname(hostname):
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def __init__(self, trust_env: bool = False, *args, **kwargs):
        """Initialize SafeWebBaseLoader.

        Args:
            trust_env (bool, optional): Set to True when web requests should go
                through a proxy, e.g. one configured via the http(s)_proxy
                environment variables. Defaults to False.
        """
        super().__init__(*args, **kwargs)
        self.trust_env = trust_env

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    kwargs: Dict = dict(
                        headers=self.session.headers,
                        cookies=self.session.cookies.get_dict(),
                    )
                    if not self.session.verify:
                        kwargs["ssl"] = False
                    async with session.get(
                        url, **(self.requests_kwargs | kwargs)
                    ) as response:
                        if self.raise_for_status:
                            response.raise_for_status()
                        return await response.text()
                except aiohttp.ClientConnectionError as e:
                    if i == retries - 1:
                        raise
                    else:
                        log.warning(
                            f"Error fetching {url} with attempt "
                            f"{i + 1}/{retries}: {e}. Retrying..."
                        )
                        await asyncio.sleep(cooldown * backoff**i)
        raise ValueError("retry count exceeded")

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
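
The new _fetch retries each URL up to `retries` times, sleeping cooldown * backoff**i between failed attempts (2 s, then 3 s with the defaults) and re-raising on the final failure. The proxy support itself comes from trust_env: when it is True, aiohttp.ClientSession reads HTTP_PROXY/HTTPS_PROXY/NO_PROXY (upper- or lower-case) from the environment and routes requests through the configured proxy. A standalone sketch of that behavior, assuming a proxy is actually listening at the placeholder address:

import asyncio
import os

import aiohttp

async def main():
    # Placeholder proxy address; replace with a real forward proxy.
    os.environ["https_proxy"] = "http://127.0.0.1:3128"
    # trust_env=True makes the session honor the proxy variables above.
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get("https://example.com") as response:
            print(response.status)

asyncio.run(main())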
@@ -96,13 +137,15 @@ def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
    trust_env: bool = False,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    return SafeWebBaseLoader(
        web_path=safe_urls,
        verify_ssl=verify_ssl,
        requests_per_second=requests_per_second,
        continue_on_failure=True,
        trust_env=trust_env,
    )
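
A usage sketch for the extended signature (the URL and proxy address are placeholders):

import os

# The loader's aiohttp session only honors these variables when trust_env=True.
os.environ["https_proxy"] = "http://127.0.0.1:3128"

loader = get_web_loader(
    "https://example.com/article",
    verify_ssl=True,
    requests_per_second=2,
    trust_env=True,
)
docs = loader.load()  # fetched through the proxy, with retries and backoff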

View File

@@ -450,6 +450,7 @@ class WebSearchConfig(BaseModel):
    exa_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None
    trust_env: Optional[bool] = None
    domain_filter_list: Optional[List[str]] = []
@@ -569,6 +570,9 @@ async def update_rag_config(
    request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
        form_data.web.search.concurrent_requests
    )
    request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
        form_data.web.search.trust_env
    )
    request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
        form_data.web.search.domain_filter_list
    )
@@ -621,6 +625,7 @@ async def update_rag_config(
            "exa_api_key": request.app.state.config.EXA_API_KEY,
            "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
            "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        },
    },
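
With trust_env exposed on WebSearchConfig, the flag can be toggled per deployment through the RAG config update endpoint rather than only via the environment. A hypothetical payload fragment (the nesting follows form_data.web.search above; only fields visible in this diff are shown):

payload = {
    "web": {
        "search": {
            "result_count": 3,
            "concurrent_requests": 10,
            "trust_env": True,  # the new flag
            "domain_filter_list": [],
        }
    }
}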
@@ -1340,6 +1345,7 @@ def process_web_search(
        urls,
        verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
    )
    docs = loader.load()
    save_docs_to_vector_db(