From 4e8b3906821a9a10f4fd0038373291dff41b65cf Mon Sep 17 00:00:00 2001
From: Rory <16675082+roryeckel@users.noreply.github.com>
Date: Tue, 28 Jan 2025 23:03:15 -0600
Subject: [PATCH] Add RAG_WEB_LOADER + Playwright mode + improve stability of
 search

---
 backend/open_webui/config.py              |   5 +
 backend/open_webui/main.py                |   2 +
 backend/open_webui/retrieval/web/main.py  |   4 +
 backend/open_webui/retrieval/web/utils.py | 179 +++++++++++++++++++---
 backend/open_webui/routers/retrieval.py   |  21 ++-
 backend/open_webui/utils/middleware.py    |  27 ++--
 backend/requirements.txt                  |   2 +-
 backend/start.sh                          |   9 ++
 backend/start_windows.bat                 |   9 ++
 pyproject.toml                            |   1 +
 10 files changed, 220 insertions(+), 39 deletions(-)

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index c37b831de..3cec6edd7 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1712,6 +1712,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_LOADER = PersistentConfig(
+    "RAG_WEB_LOADER",
+    "rag.web.loader",
+    os.environ.get("RAG_WEB_LOADER", "safe_web")
+)
 
 ####################################
 # Images
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 00270aabc..985624d81 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -129,6 +129,7 @@ from open_webui.config import (
     AUDIO_TTS_VOICE,
     AUDIO_TTS_AZURE_SPEECH_REGION,
     AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+    RAG_WEB_LOADER,
     WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
@@ -526,6 +527,7 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_K
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None
diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py
index 1af8a70aa..28a749e7d 100644
--- a/backend/open_webui/retrieval/web/main.py
+++ b/backend/open_webui/retrieval/web/main.py
@@ -1,3 +1,5 @@
+import validators
+
 from typing import Optional
 from urllib.parse import urlparse
 
@@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
     filtered_results = []
     for result in results:
         url = result.get("url") or result.get("link", "")
+        if not validators.url(url):
+            continue
         domain = urlparse(url).netloc
         if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
             filtered_results.append(result)
diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index a322bbbfc..bdc626749 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -1,16 +1,21 @@
+import asyncio, time
+from datetime import datetime, timedelta
 import socket
+import ssl
 import urllib.parse
 
+import certifi
 import validators
-from typing import Union, Sequence, Iterator
+from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
+    PlaywrightURLLoader
 )
 from langchain_core.documents import Document
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
+from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, RAG_WEB_LOADER
 from open_webui.env import SRC_LOG_LEVELS
 
 import logging
@@ -42,6 +47,15 @@ def validate_url(url: Union[str, Sequence[str]]):
     else:
         return False
 
+def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
+    valid_urls = []
+    for u in url:
+        try:
+            if validate_url(u):
+                valid_urls.append(u)
+        except ValueError:
+            continue
+    return valid_urls
 
 def resolve_hostname(hostname):
     # Get address information
@@ -53,6 +67,131 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
 
 
+def extract_metadata(soup, url):
+    metadata = {
+        "source": url
+    }
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get(
+            "content", "No description found."
+        )
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", "No language found.")
+    return metadata
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification and rate limiting.
+
+    Attributes:
+        urls (List[str]): List of URLs to load.
+        verify_ssl (bool): If True, verify SSL certificates.
+        requests_per_second (Optional[float]): Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        headless (bool): If True, the browser will run in headless mode.
+    """
+
+    def __init__(
+        self,
+        urls: List[str],
+        verify_ssl: bool = True,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        headless: bool = True,
+        remove_selectors: Optional[List[str]] = None,
+        proxy: Optional[Dict[str, str]] = None
+    ):
+        """Initialize with additional safety parameters."""
+        super().__init__(
+            urls=urls,
+            continue_on_failure=continue_on_failure,
+            headless=headless,
+            remove_selectors=remove_selectors,
+            proxy=proxy
+        )
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for the given URL."""
+        if not url.startswith("https://"):
+            return True
+
+        try:
+            hostname = url.split("://")[-1].split("/")[0]
+            context = ssl.create_default_context(cafile=certifi.where())
+            with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
+                s.connect((hostname, 443))
+            return True
+        except ssl.SSLError:
+            return False
+        except Exception as e:
+            log.warning(f"SSL verification failed for {url}: {str(e)}")
+            return False
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Safely load URLs asynchronously."""
+        parent_iterator = super().alazy_load()
+
+        async for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                await self._safe_process_url(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Safely load URLs synchronously."""
+        parent_iterator = super().lazy_load()
+
+        for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                self._safe_process_url_sync(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
@@ -65,15 +204,7 @@ class SafeWebBaseLoader(WebBaseLoader):
                 text = soup.get_text(**self.bs_get_text_kwargs)
 
                 # Build metadata
-                metadata = {"source": path}
-                if title := soup.find("title"):
-                    metadata["title"] = title.get_text()
-                if description := soup.find("meta", attrs={"name": "description"}):
-                    metadata["description"] = description.get(
-                        "content", "No description found."
-                    )
-                if html := soup.find("html"):
-                    metadata["language"] = html.get("lang", "No language found.")
+                metadata = extract_metadata(soup, path)
 
                 yield Document(page_content=text, metadata=metadata)
             except Exception as e:
@@ -87,11 +218,21 @@ def get_web_loader(
     requests_per_second: int = 2,
 ):
     # Check if the URL is valid
-    if not validate_url(urls):
-        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return SafeWebBaseLoader(
-        urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
-    )
+    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
+
+    if RAG_WEB_LOADER.value == "chromium":
+        log.info("Using SafePlaywrightURLLoader")
+        return SafePlaywrightURLLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )
+    else:
+        log.info("Using SafeWebBaseLoader")
+        return SafeWebBaseLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 2cffd9ead..e65a76050 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1238,9 +1238,11 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 @router.post("/process/web/search")
-def process_web_search(
-    request: Request, form_data: SearchForm, user=Depends(get_verified_user)
+async def process_web_search(
+    request: Request, form_data: SearchForm, extra_params: dict, user=Depends(get_verified_user)
 ):
+    event_emitter = extra_params["__event_emitter__"]
+
     try:
         logging.info(
             f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
         )
@@ -1258,6 +1260,18 @@
 
     log.debug(f"web_results: {web_results}")
 
+    await event_emitter(
+        {
+            "type": "status",
+            "data": {
+                "action": "web_search",
+                "description": "Loading {{count}} sites...",
+                "urls": [result.link for result in web_results],
+                "done": False
+            },
+        }
+    )
+
     try:
         collection_name = form_data.collection_name
         if collection_name == "" or collection_name is None:
@@ -1271,7 +1285,8 @@
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
+        docs = [doc async for doc in loader.alazy_load()]
+        # docs = loader.load()
         save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
 
         return {
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 6b2329be1..27e499e0c 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -419,21 +419,16 @@ async def chat_web_search_handler(
 
     try:
 
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
-                    request,
-                    SearchForm(
-                        **{
-                            "query": searchQuery,
-                        }
-                    ),
-                    user,
-                ),
-            )
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            extra_params=extra_params,
+            user=user
+        )
 
         if results:
             await event_emitter(
@@ -441,7 +436,7 @@
                     "type": "status",
                     "data": {
                         "action": "web_search",
-                        "description": "Searched {{count}} sites",
+                        "description": "Loaded {{count}} sites",
                         "query": searchQuery,
                         "urls": results["filenames"],
                         "done": True,
diff --git a/backend/requirements.txt b/backend/requirements.txt
index eecb9c4a5..0dd7b1a8a 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -46,7 +46,7 @@ chromadb==0.6.2
 pymilvus==2.5.0
 qdrant-client~=1.12.0
 opensearch-py==2.7.1
-
+playwright==1.49.1
 
 transformers
 sentence-transformers==3.3.1
diff --git a/backend/start.sh b/backend/start.sh
index a945acb62..ce56b1867 100755
--- a/backend/start.sh
+++ b/backend/start.sh
@@ -3,6 +3,15 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd "$SCRIPT_DIR" || exit
 
+# Add conditional Playwright browser installation
+if [[ "${RAG_WEB_LOADER,,}" == "chromium" ]]; then
+    echo "Installing Playwright browsers..."
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+fi
+
 KEY_FILE=.webui_secret_key
 
 PORT="${PORT:-8080}"
diff --git a/backend/start_windows.bat b/backend/start_windows.bat
index 3e8c6b97c..3b6446258 100644
--- a/backend/start_windows.bat
+++ b/backend/start_windows.bat
@@ -6,6 +6,15 @@ SETLOCAL ENABLEDELAYEDEXPANSION
 SET "SCRIPT_DIR=%~dp0"
 cd /d "%SCRIPT_DIR%" || exit /b
 
+:: Add conditional Playwright browser installation
+IF /I "%RAG_WEB_LOADER%" == "chromium" (
+    echo Installing Playwright browsers...
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+)
+
 SET "KEY_FILE=.webui_secret_key"
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0
diff --git a/pyproject.toml b/pyproject.toml
index edd01db8f..c8ec0f497 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ dependencies = [
     "pymilvus==2.5.0",
     "qdrant-client~=1.12.0",
     "opensearch-py==2.7.1",
+    "playwright==1.49.1",
    "transformers",
     "sentence-transformers==3.3.1",
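
A minimal usage sketch of the loader selection introduced by this patch, not an authoritative example: it assumes the open_webui backend package and its configuration can be imported in your environment, that RAG_WEB_LOADER is exported before open_webui.config is first imported, and that Playwright's Chromium has been installed via "playwright install chromium"; the URL below is a placeholder.

import asyncio
import os

# RAG_WEB_LOADER is read from the environment when open_webui.config is imported,
# so it must be set first. "safe_web" (the default) selects SafeWebBaseLoader;
# "chromium" selects SafePlaywrightURLLoader.
os.environ["RAG_WEB_LOADER"] = "chromium"

from open_webui.retrieval.web.utils import get_web_loader


async def main():
    loader = get_web_loader(
        ["https://example.com/"],  # placeholder URL
        verify_ssl=True,
        requests_per_second=2,
    )
    # Same async iteration the /process/web/search route now uses.
    docs = [doc async for doc in loader.alazy_load()]
    for doc in docs:
        print(doc.metadata.get("source"), len(doc.page_content))


asyncio.run(main())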