Mirror of https://github.com/open-webui/open-webui (synced 2025-06-26 18:26:48 +00:00)
	Add RAG_WEB_LOADER + Playwright mode + improve stability of search
This commit is contained in:
parent b72150c881
commit 4e8b390682
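Editor's note: this commit introduces a RAG_WEB_LOADER setting that selects how web-search results are fetched for RAG: the default "safe_web" keeps the existing SafeWebBaseLoader, while "chromium" switches to a new Playwright-based SafePlaywrightURLLoader. It also makes the /process/web/search route asynchronous and emits loading-status events while pages are fetched. A minimal sketch of opting in (assumes the setting is supplied through the environment, as the config hunk below reads it):

    # Sketch: enable the Playwright-backed loader before the backend starts.
    # "safe_web" (the default) keeps the current behaviour.
    import os

    os.environ["RAG_WEB_LOADER"] = "chromium"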
					
@@ -1712,6 +1712,11 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
     int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
 )
 
+RAG_WEB_LOADER = PersistentConfig(
+    "RAG_WEB_LOADER",
+    "rag.web.loader",
+    os.environ.get("RAG_WEB_LOADER", "safe_web")
+)
 
 ####################################
 # Images
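Editor's note: the new PersistentConfig entry binds an environment variable ("RAG_WEB_LOADER"), a key in the persisted config store ("rag.web.loader"), and a default ("safe_web"); downstream code reads the resolved value through its .value attribute, as the get_web_loader hunk further down shows. A rough sketch of the consumer side (illustration only):

    # Sketch: how the setting added above is consumed elsewhere in this diff.
    from open_webui.config import RAG_WEB_LOADER

    use_playwright = RAG_WEB_LOADER.value == "chromium"  # otherwise the default "safe_web" path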
@@ -129,6 +129,7 @@ from open_webui.config import (
     AUDIO_TTS_VOICE,
     AUDIO_TTS_AZURE_SPEECH_REGION,
     AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+    RAG_WEB_LOADER,
     WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
@@ -526,6 +527,7 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
 
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+app.state.config.RAG_WEB_LOADER = RAG_WEB_LOADER
 
 app.state.EMBEDDING_FUNCTION = None
 app.state.ef = None
@@ -1,3 +1,5 @@
+import validators
+
 from typing import Optional
 from urllib.parse import urlparse
 
@@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
     filtered_results = []
     for result in results:
         url = result.get("url") or result.get("link", "")
+        if not validators.url(url):
+            continue
         domain = urlparse(url).netloc
         if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
             filtered_results.append(result)
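Editor's note: the added validators.url() guard skips search results whose url/link field is not a well-formed URL before the domain filter runs. A small illustration with hypothetical data:

    # Sketch: malformed URLs are now skipped instead of being parsed for a domain.
    results = [
        {"url": "https://example.com/a"},
        {"link": "not a url"},  # fails validators.url() and is dropped
    ]
    get_filtered_results(results, ["example.com"])
    # -> [{"url": "https://example.com/a"}]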
@@ -1,16 +1,21 @@
+import asyncio
+from datetime import datetime, time, timedelta
 import socket
+import ssl
 import urllib.parse
+import certifi
 import validators
-from typing import Union, Sequence, Iterator
+from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
+    PlaywrightURLLoader
 )
 from langchain_core.documents import Document
 
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
+from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH, RAG_WEB_LOADER
 from open_webui.env import SRC_LOG_LEVELS
 
 import logging
@@ -42,6 +47,15 @@ def validate_url(url: Union[str, Sequence[str]]):
     else:
         return False
 
+def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
+    valid_urls = []
+    for u in url:
+        try:
+            if validate_url(u):
+                valid_urls.append(u)
+        except ValueError:
+            continue
+    return valid_urls
 
 def resolve_hostname(hostname):
     # Get address information
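Editor's note: safe_validate_urls is a lenient counterpart to the existing validate_url, which raises ValueError for URLs that are invalid or blocked (for example private addresses while ENABLE_RAG_LOCAL_WEB_FETCH is off); failing entries are dropped so a single bad search result no longer aborts the whole batch. Illustration (hypothetical values; the exact outcome depends on configuration and DNS):

    # Sketch: only URLs that pass validate_url survive; failures are skipped silently.
    urls = ["https://docs.openwebui.com", "http://127.0.0.1/internal", "not-a-url"]
    safe_validate_urls(urls)
    # e.g. ["https://docs.openwebui.com"] with local web fetch disabled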
@@ -53,6 +67,131 @@ def resolve_hostname(hostname):
 
     return ipv4_addresses, ipv6_addresses
 
+def extract_metadata(soup, url):
+    metadata = {
+        "source": url
+    }
+    if title := soup.find("title"):
+        metadata["title"] = title.get_text()
+    if description := soup.find("meta", attrs={"name": "description"}):
+        metadata["description"] = description.get(
+            "content", "No description found."
+        )
+    if html := soup.find("html"):
+        metadata["language"] = html.get("lang", "No language found.")
+    return metadata
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification and rate limiting.
+
+    Attributes:
+        urls (List[str]): List of URLs to load.
+        verify_ssl (bool): If True, verify SSL certificates.
+        requests_per_second (Optional[float]): Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        headless (bool): If True, the browser will run in headless mode.
+    """
+
+    def __init__(
+        self,
+        urls: List[str],
+        verify_ssl: bool = True,
+        requests_per_second: Optional[float] = None,
+        continue_on_failure: bool = True,
+        headless: bool = True,
+        remove_selectors: Optional[List[str]] = None,
+        proxy: Optional[Dict[str, str]] = None
+    ):
+        """Initialize with additional safety parameters."""
+        super().__init__(
+            urls=urls,
+            continue_on_failure=continue_on_failure,
+            headless=headless,
+            remove_selectors=remove_selectors,
+            proxy=proxy
+        )
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        """Verify SSL certificate for the given URL."""
+        if not url.startswith("https://"):
+            return True
+
+        try:
+            hostname = url.split("://")[-1].split("/")[0]
+            context = ssl.create_default_context(cafile=certifi.where())
+            with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
+                s.connect((hostname, 443))
+            return True
+        except ssl.SSLError:
+            return False
+        except Exception as e:
+            log.warning(f"SSL verification failed for {url}: {str(e)}")
+            return False
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                time.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Safely load URLs asynchronously."""
+        parent_iterator = super().alazy_load()
+
+        async for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                await self._safe_process_url(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Safely load URLs synchronously."""
+        parent_iterator = super().lazy_load()
+
+        for document in parent_iterator:
+            url = document.metadata["source"]
+            try:
+                self._safe_process_url_sync(url)
+                yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.error(f"Error processing {url}, exception: {e}")
+                    continue
+                raise e
 
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
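Editor's note: a rough standalone usage sketch of the new SafePlaywrightURLLoader (assumes the playwright package and a Chromium browser are installed, e.g. via `playwright install chromium`). One caveat visible in the diff: _sync_wait_for_rate_limit calls time.sleep, but time is imported from datetime at the top of this file, so the synchronous rate-limit path would raise AttributeError unless a real time module import exists elsewhere in the module.

    # Sketch: fetch a page headlessly with SSL verification and a request-rate cap.
    loader = SafePlaywrightURLLoader(
        urls=["https://docs.openwebui.com"],
        verify_ssl=True,
        requests_per_second=2.0,
    )
    docs = loader.load()  # sync path; BaseLoader.load() iterates lazy_load()
    # async path: docs = [doc async for doc in loader.alazy_load()]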
@@ -65,15 +204,7 @@ class SafeWebBaseLoader(WebBaseLoader):
                 text = soup.get_text(**self.bs_get_text_kwargs)
 
                 # Build metadata
-                metadata = {"source": path}
-                if title := soup.find("title"):
-                    metadata["title"] = title.get_text()
-                if description := soup.find("meta", attrs={"name": "description"}):
-                    metadata["description"] = description.get(
-                        "content", "No description found."
-                    )
-                if html := soup.find("html"):
-                    metadata["language"] = html.get("lang", "No language found.")
+                metadata = extract_metadata(soup, path)
 
                 yield Document(page_content=text, metadata=metadata)
             except Exception as e:
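Editor's note: extract_metadata (added above) centralizes the title/description/language scraping that this hunk removes from SafeWebBaseLoader.lazy_load. Quick illustration with a hypothetical page:

    # Sketch: metadata extracted from a parsed document.
    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<html lang='en'><head><title>Docs</title></head></html>", "html.parser")
    extract_metadata(soup, "https://example.com")
    # -> {"source": "https://example.com", "title": "Docs", "language": "en"}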
@@ -87,11 +218,21 @@ def get_web_loader(
     requests_per_second: int = 2,
 ):
     # Check if the URL is valid
-    if not validate_url(urls):
-        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return SafeWebBaseLoader(
-        urls,
-        verify_ssl=verify_ssl,
-        requests_per_second=requests_per_second,
-        continue_on_failure=True,
-    )
+    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
+
+    if RAG_WEB_LOADER.value == "chromium":
+        log.info("Using SafePlaywrightURLLoader")
+        return SafePlaywrightURLLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )
+    else:
+        log.info("Using SafeWebBaseLoader")
+        return SafeWebBaseLoader(
+            safe_urls,
+            verify_ssl=verify_ssl,
+            requests_per_second=requests_per_second,
+            continue_on_failure=True,
+        )
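Editor's note: get_web_loader no longer raises on the first bad URL; it filters the list through safe_validate_urls and then picks the loader class from the persisted setting. Selection sketch (illustration only):

    # Sketch: the loader implementation is chosen by RAG_WEB_LOADER.
    loader = get_web_loader(["https://docs.openwebui.com"], verify_ssl=True, requests_per_second=2)
    # RAG_WEB_LOADER = "safe_web"  -> SafeWebBaseLoader (default, previous behaviour)
    # RAG_WEB_LOADER = "chromium" -> SafePlaywrightURLLoader (Playwright/Chromium)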
@@ -1238,9 +1238,11 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
-    request: Request, form_data: SearchForm, user=Depends(get_verified_user)
+async def process_web_search(
+    request: Request, form_data: SearchForm, extra_params: dict, user=Depends(get_verified_user)
 ):
+    event_emitter = extra_params["__event_emitter__"]
+
     try:
         logging.info(
             f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@@ -1258,6 +1260,18 @@ def process_web_search(
 
     log.debug(f"web_results: {web_results}")
 
+    await event_emitter(
+        {
+            "type": "status",
+            "data": {
+                "action": "web_search",
+                "description": "Loading {{count}} sites...",
+                "urls": [result.link for result in web_results],
+                "done": False
+            },
+        }
+    )
+
     try:
         collection_name = form_data.collection_name
         if collection_name == "" or collection_name is None:
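Editor's note: the route now reports progress through the __event_emitter__ callback passed in via extra_params before the pages are loaded. A minimal sketch of a compatible emitter (the real one is supplied by the chat middleware; this stand-in only prints):

    # Sketch: an async callback receiving status events like the one emitted above.
    async def event_emitter(event: dict) -> None:
        if event.get("type") == "status":
            data = event["data"]
            print(data["action"], data["description"], data.get("done"))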
@@ -1271,7 +1285,8 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
+        docs = [doc async for doc in loader.alazy_load()]
+        # docs = loader.load()
         save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
 
         return {
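Editor's note: because process_web_search is now a coroutine, documents are gathered with an async comprehension over alazy_load() instead of the blocking load(). The same pattern in isolation (sketch; SafeWebBaseLoader itself only overrides lazy_load, so this assumes langchain's default BaseLoader.alazy_load wrapper):

    # Sketch: asynchronous document collection from either loader.
    async def collect(loader):
        return [doc async for doc in loader.alazy_load()]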
@@ -419,21 +419,16 @@ async def chat_web_search_handler(
 
     try:
 
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
-                    request,
-                    SearchForm(
-                        **{
-                            "query": searchQuery,
-                        }
-                    ),
-                    user,
-                ),
-            )
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            extra_params=extra_params,
+            user=user
+        )
 
         if results:
             await event_emitter(
@@ -441,7 +436,7 @@ async def chat_web_search_handler(
                     "type": "status",
                     "data": {
                         "action": "web_search",
-                        "description": "Searched {{count}} sites",
+                        "description": "Loaded {{count}} sites",
                         "query": searchQuery,
                         "urls": results["filenames"],
                         "done": True,
@@ -46,7 +46,7 @@ chromadb==0.6.2
 pymilvus==2.5.0
 qdrant-client~=1.12.0
 opensearch-py==2.7.1
-
+playwright==1.49.1
 
 transformers
 sentence-transformers==3.3.1
@@ -3,6 +3,15 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd "$SCRIPT_DIR" || exit
 
+# Add conditional Playwright browser installation
+if [[ "${RAG_WEB_LOADER,,}" == "chromium" ]]; then
+    echo "Installing Playwright browsers..."
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+fi
+
 KEY_FILE=.webui_secret_key
 
 PORT="${PORT:-8080}"
@@ -6,6 +6,15 @@ SETLOCAL ENABLEDELAYEDEXPANSION
 SET "SCRIPT_DIR=%~dp0"
 cd /d "%SCRIPT_DIR%" || exit /b
 
+:: Add conditional Playwright browser installation
+IF /I "%RAG_WEB_LOADER%" == "chromium" (
+    echo Installing Playwright browsers...
+    playwright install chromium
+    playwright install-deps chromium
+
+    python -c "import nltk; nltk.download('punkt_tab')"
+)
+
 SET "KEY_FILE=.webui_secret_key"
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0
@@ -53,6 +53,7 @@ dependencies = [
     "pymilvus==2.5.0",
     "qdrant-client~=1.12.0",
     "opensearch-py==2.7.1",
+    "playwright==1.49.1",
 
     "transformers",
     "sentence-transformers==3.3.1",