Add RAG_WEB_LOADER + Playwright mode + improve stability of search

Rory 2025-01-28 23:03:15 -06:00
parent b72150c881
commit 4e8b390682
10 changed files with 220 additions and 39 deletions

View File

@ -1712,6 +1712,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
RAG_WEB_LOADER = PersistentConfig(
"RAG_WEB_LOADER",
"rag.web.loader",
os.environ.get("RAG_WEB_LOADER", "safe_web")
)
####################################
# Images

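Note: a minimal usage sketch (not part of the diff) for the new setting. RAG_WEB_LOADER defaults to "safe_web"; "chromium" is the only other value this commit checks, and it switches web retrieval to the Playwright-based loader added below.

import os

# Choose the loader backend before the server starts: "safe_web" keeps the
# existing SafeWebBaseLoader, "chromium" enables SafePlaywrightURLLoader.
os.environ.setdefault("RAG_WEB_LOADER", "chromium")
print(os.environ.get("RAG_WEB_LOADER", "safe_web"))  # -> "chromium"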
View File

@ -129,6 +129,7 @@ from open_webui.config import (
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
RAG_WEB_LOADER,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
@ -526,6 +527,7 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_K
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None

View File

@ -1,3 +1,5 @@
import validators
from typing import Optional
from urllib.parse import urlparse
@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
filtered_results = []
for result in results:
url = result.get("url") or result.get("link", "")
if not validators.url(url):
continue
domain = urlparse(url).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)

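Note: an illustrative sketch of the stricter filtering (the sample results and filter_list are hypothetical). URLs that fail validators.url() are now skipped before the domain allow-list check instead of reaching urlparse().

import validators
from urllib.parse import urlparse

results = [
    {"url": "https://docs.example.org/page"},
    {"link": "not-a-valid-url"},  # skipped by validators.url()
]
filter_list = ["example.org"]

filtered_results = []
for result in results:
    url = result.get("url") or result.get("link", "")
    if not validators.url(url):
        continue
    domain = urlparse(url).netloc
    if any(domain.endswith(d) for d in filter_list):
        filtered_results.append(result)
# filtered_results now holds only the well-formed example.org entry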
View File

@ -1,16 +1,21 @@
import asyncio
import time
from datetime import datetime, timedelta
import socket
import ssl
import urllib.parse
import certifi
import validators
from typing import Union, Sequence, Iterator
from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
from langchain_community.document_loaders import (
WebBaseLoader,
PlaywrightURLLoader
)
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, RAG_WEB_LOADER
from open_webui.env import SRC_LOG_LEVELS
import logging
@ -42,6 +47,15 @@ def validate_url(url: Union[str, Sequence[str]]):
else:
return False
def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
valid_urls = []
for u in url:
try:
if validate_url(u):
valid_urls.append(u)
except ValueError:
continue
return valid_urls
def resolve_hostname(hostname):
# Get address information
@ -53,6 +67,131 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
metadata = {
"source": url
}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
return metadata
class SafePlaywrightURLLoader(PlaywrightURLLoader):
"""Load HTML pages safely with Playwright, supporting SSL verification and rate limiting.
Attributes:
urls (List[str]): List of URLs to load.
verify_ssl (bool): If True, verify SSL certificates.
requests_per_second (Optional[float]): Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
headless (bool): If True, the browser will run in headless mode.
"""
def __init__(
self,
urls: List[str],
verify_ssl: bool = True,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
headless: bool = True,
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None
):
"""Initialize with additional safety parameters."""
super().__init__(
urls=urls,
continue_on_failure=continue_on_failure,
headless=headless,
remove_selectors=remove_selectors,
proxy=proxy
)
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
def _verify_ssl_cert(self, url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
context = ssl.create_default_context(cafile=certifi.where())
with context.wrap_socket(socket.socket(), server_hostname=hostname) as s:
s.connect((hostname, 443))
return True
except ssl.SSLError:
return False
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
return False
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
async def alazy_load(self) -> AsyncIterator[Document]:
"""Safely load URLs asynchronously."""
parent_iterator = super().alazy_load()
async for document in parent_iterator:
url = document.metadata["source"]
try:
await self._safe_process_url(url)
yield document
except Exception as e:
if self.continue_on_failure:
log.error(f"Error processing {url}, exception: {e}")
continue
raise e
def lazy_load(self) -> Iterator[Document]:
"""Safely load URLs synchronously."""
parent_iterator = super().lazy_load()
for document in parent_iterator:
url = document.metadata["source"]
try:
self._safe_process_url_sync(url)
yield document
except Exception as e:
if self.continue_on_failure:
log.error(f"Error processing {url}, exception: {e}")
continue
raise e
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
@ -65,15 +204,7 @@ class SafeWebBaseLoader(WebBaseLoader):
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = extract_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)
except Exception as e:
@ -87,11 +218,21 @@ def get_web_loader(
requests_per_second: int = 2,
):
# Check if the URL is valid
if not validate_url(urls):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return SafeWebBaseLoader(
urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
)
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
if RAG_WEB_LOADER.value == "chromium":
log.info("Using SafePlaywrightURLLoader")
return SafePlaywrightURLLoader(
safe_urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
)
else:
log.info("Using SafeWebBaseLoader")
return SafeWebBaseLoader(
safe_urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
)

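Note: a hedged usage sketch of the new loader selection (the import path and URLs are assumptions; the chromium branch additionally requires Playwright browsers, e.g. playwright install chromium). Both loaders expose the async iterator consumed by the retrieval router below.

import asyncio

from open_webui.retrieval.web.utils import get_web_loader  # module path assumed

async def fetch_docs():
    loader = get_web_loader(
        ["https://example.com", "https://example.org"],
        verify_ssl=True,
        requests_per_second=2,
    )
    # alazy_load() yields Documents one by one, applying the SSL and
    # rate-limit checks added in this commit.
    return [doc async for doc in loader.alazy_load()]

docs = asyncio.run(fetch_docs())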
View File

@ -1238,9 +1238,11 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
@router.post("/process/web/search")
def process_web_search(
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
async def process_web_search(
request: Request, form_data: SearchForm, extra_params: dict, user=Depends(get_verified_user)
):
event_emitter = extra_params["__event_emitter__"]
try:
logging.info(
f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@ -1258,6 +1260,18 @@ def process_web_search(
log.debug(f"web_results: {web_results}")
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Loading {{count}} sites...",
"urls": [result.link for result in web_results],
"done": False
},
}
)
try:
collection_name = form_data.collection_name
if collection_name == "" or collection_name is None:
@ -1271,7 +1285,8 @@ def process_web_search(
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
docs = loader.load()
docs = [doc async for doc in loader.alazy_load()]
# docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
return {

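Note: a sketch of the new call contract (fake_event_emitter is illustrative; request, SearchForm, and user come from the surrounding FastAPI handler). The endpoint is now a coroutine and reads the emitter from extra_params["__event_emitter__"], matching the middleware call site below.

async def fake_event_emitter(event: dict) -> None:
    # Receives the {"type": "status", ...} payloads emitted while sites load.
    print(event["type"], event["data"].get("description"))

results = await process_web_search(
    request,
    SearchForm(query="playwright web loader"),
    extra_params={"__event_emitter__": fake_event_emitter},
    user=user,
)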
View File

@ -419,21 +419,16 @@ async def chat_web_search_handler(
try:
# Offload process_web_search to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
results = await loop.run_in_executor(
executor,
lambda: process_web_search(
request,
SearchForm(
**{
"query": searchQuery,
}
),
user,
),
)
results = await process_web_search(
request,
SearchForm(
**{
"query": searchQuery,
}
),
extra_params=extra_params,
user=user
)
if results:
await event_emitter(
@ -441,7 +436,7 @@ async def chat_web_search_handler(
"type": "status",
"data": {
"action": "web_search",
"description": "Searched {{count}} sites",
"description": "Loaded {{count}} sites",
"query": searchQuery,
"urls": results["filenames"],
"done": True,

View File

@ -46,7 +46,7 @@ chromadb==0.6.2
pymilvus==2.5.0
qdrant-client~=1.12.0
opensearch-py==2.7.1
playwright==1.49.1
transformers
sentence-transformers==3.3.1

View File

@ -3,6 +3,15 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd "$SCRIPT_DIR" || exit
# Add conditional Playwright browser installation
if [[ "${RAG_WEB_LOADER,,}" == "chromium" ]]; then
echo "Installing Playwright browsers..."
playwright install chromium
playwright install-deps chromium
python -c "import nltk; nltk.download('punkt_tab')"
fi
KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"

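Note: an optional sanity check (an assumption, not part of the startup script) that the chromium browser installed by the conditional step can actually be launched.

from playwright.sync_api import sync_playwright

with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)  # raises if browsers are missing
    browser.close()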
View File

@ -6,6 +6,15 @@ SETLOCAL ENABLEDELAYEDEXPANSION
SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b
:: Add conditional Playwright browser installation
IF /I "%RAG_WEB_LOADER%" == "chromium" (
echo Installing Playwright browsers...
playwright install chromium
playwright install-deps chromium
python -c "import nltk; nltk.download('punkt_tab')"
)
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0

View File

@ -53,6 +53,7 @@ dependencies = [
"pymilvus==2.5.0",
"qdrant-client~=1.12.0",
"opensearch-py==2.7.1",
"playwright==1.49.1",
"transformers",
"sentence-transformers==3.3.1",