web loader: support proxy

Yimi81 2025-02-14 07:15:09 +00:00
parent 304aed0f13
commit d3f71930f0
4 changed files with 58 additions and 2 deletions

@@ -1853,6 +1853,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
+    "RAG_WEB_SEARCH_TRUST_ENV",
+    "rag.web.search.trust_env",
+    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
+)
 
 ####################################
 # Images

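One caveat worth noting: os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False) returns the raw string (e.g. "true", or even "false") whenever the variable is set, and the bool False only when it is unset, so any non-empty value ends up truthy. A minimal sketch of a strict boolean coercion, not part of this commit:

# Sketch only (hypothetical): parse the environment value into a real
# boolean instead of passing the raw string through.
import os

RAG_WEB_SEARCH_TRUST_ENV = (
    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true"
)
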
@@ -175,6 +175,7 @@ from open_webui.config import (
     RAG_WEB_SEARCH_ENGINE,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_TRUST_ENV,
     RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     JINA_API_KEY,
     SEARCHAPI_API_KEY,
@@ -558,6 +559,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None

@@ -1,7 +1,9 @@
 import socket
+import aiohttp
+import asyncio
 import urllib.parse
 import validators
-from typing import Union, Sequence, Iterator
+from typing import Union, Sequence, Iterator, Dict
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -68,6 +70,45 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def __init__(self, trust_env: bool = False, *args, **kwargs):
+        """Initialize SafeWebBaseLoader
+        Args:
+            trust_env (bool, optional): set to True if using proxy to make web requests, for example
+                using http(s)_proxy environment variables. Defaults to False.
+        """
+        super().__init__(*args, **kwargs)
+        self.trust_env = trust_env
+
+    async def _fetch(
+        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
+    ) -> str:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
+            for i in range(retries):
+                try:
+                    kwargs: Dict = dict(
+                        headers=self.session.headers,
+                        cookies=self.session.cookies.get_dict(),
+                    )
+                    if not self.session.verify:
+                        kwargs["ssl"] = False
+
+                    async with session.get(
+                        url, **(self.requests_kwargs | kwargs)
+                    ) as response:
+                        if self.raise_for_status:
+                            response.raise_for_status()
+                        return await response.text()
+                except aiohttp.ClientConnectionError as e:
+                    if i == retries - 1:
+                        raise
+                    else:
+                        log.warning(
+                            f"Error fetching {url} with attempt "
+                            f"{i + 1}/{retries}: {e}. Retrying..."
+                        )
+                        await asyncio.sleep(cooldown * backoff**i)
+        raise ValueError("retry count exceeded")
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
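With the defaults above (retries=3, cooldown=2, backoff=1.5), _fetch sleeps cooldown * backoff**i between attempts and re-raises on the final failure, so at most two delays occur. A quick sanity check:

# Sketch only: the delays slept between the three attempts with the
# default parameters; the last attempt raises instead of sleeping.
retries, cooldown, backoff = 3, 2, 1.5
delays = [cooldown * backoff**i for i in range(retries - 1)]  # [2.0, 3.0] seconds
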
@@ -96,13 +137,15 @@ def get_web_loader(
     urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
+    trust_env: bool = False,
 ):
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
     return SafeWebBaseLoader(
-        safe_urls,
+        web_path=safe_urls,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
+        trust_env=trust_env
     )

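Because trust_env is forwarded into aiohttp.ClientSession, setting it to True makes the loader pick up the standard http_proxy/https_proxy environment variables (and ~/.netrc credentials, per aiohttp's trust_env semantics). A minimal usage sketch with a hypothetical proxy address:

# Sketch only: route web loading through an env-configured proxy.
import os

os.environ["http_proxy"] = "http://proxy.example:3128"   # hypothetical proxy
os.environ["https_proxy"] = "http://proxy.example:3128"  # hypothetical proxy

loader = get_web_loader(
    ["https://example.com"],
    verify_ssl=True,
    requests_per_second=2,
    trust_env=True,  # without this, the proxy variables are ignored
)
docs = loader.load()
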
@@ -450,6 +450,7 @@ class WebSearchConfig(BaseModel):
     exa_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
+    trust_env: Optional[bool] = None
     domain_filter_list: Optional[List[str]] = []
@@ -569,6 +570,9 @@ async def update_rag_config(
         request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
         )
+        request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
+            form_data.web.search.trust_env
+        )
         request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
             form_data.web.search.domain_filter_list
         )
@@ -621,6 +625,7 @@ async def update_rag_config(
             "exa_api_key": request.app.state.config.EXA_API_KEY,
             "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
             "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
             "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
         },
     },
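For reference, a sketch of the web-search section of an update_rag_config payload with the new field; the field names are taken from WebSearchConfig above, while the surrounding structure is assumed:

# Sketch only: the "web.search" portion of the config-update body.
payload = {
    "web": {
        "search": {
            "result_count": 3,
            "concurrent_requests": 10,
            "trust_env": True,  # new: honour proxy environment variables
            "domain_filter_list": [],
        }
    }
}
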
@@ -1340,6 +1345,7 @@ def process_web_search(
             urls,
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
         )
         docs = loader.load()
         save_docs_to_vector_db(