diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index 3a73444b3..7fdf2f3c6 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -1,7 +1,9 @@
 import socket
+import aiohttp
+import asyncio
 import urllib.parse
 import validators
-from typing import Union, Sequence, Iterator
+from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -68,6 +70,31 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def _unpack_fetch_results(
+        self, results: Any, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Unpack fetch results into BeautifulSoup objects."""
+        from bs4 import BeautifulSoup
+
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            # Choose the parser per URL so one ".xml" URL does not pin
+            # the parser for every subsequent result in the batch.
+            url_parser = parser
+            if url_parser is None:
+                url_parser = "xml" if url.endswith(".xml") else self.default_parser
+            self._check_parser(url_parser)
+            final_results.append(BeautifulSoup(result, url_parser, **self.bs_kwargs))
+        return final_results
+
+    async def ascrape_all(
+        self, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Async fetch all URLs, then return soups for all results."""
+        results = await self.fetch_all(urls)
+        return self._unpack_fetch_results(results, urls, parser=parser)
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
@@ -91,6 +118,26 @@ class SafeWebBaseLoader(WebBaseLoader):
                 # Log the error and continue with the next URL
                 log.error(f"Error loading {path}: {e}")
 
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async lazy load text from the url(s) in web_path."""
+        results = await self.ascrape_all(self.web_paths)
+        for path, soup in zip(self.web_paths, results):
+            text = soup.get_text(**self.bs_get_text_kwargs)
+            metadata = {"source": path}
+            if title := soup.find("title"):
+                metadata["title"] = title.get_text()
+            if description := soup.find("meta", attrs={"name": "description"}):
+                metadata["description"] = description.get(
+                    "content", "No description found."
+                )
+            if html := soup.find("html"):
+                metadata["language"] = html.get("lang", "No language found.")
+            yield Document(page_content=text, metadata=metadata)
+
+    async def aload(self) -> list[Document]:
+        """Load data into Document objects."""
+        return [document async for document in self.alazy_load()]
+
 
 def get_web_loader(
     urls: Union[str, Sequence[str]],
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 415d3bbb5..c8ca1d85b 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -21,6 +21,7 @@ from fastapi import (
     APIRouter,
 )
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
 import tiktoken
 
@@ -1308,7 +1309,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
+async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
     try:
@@ -1341,15 +1342,23 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
-        save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
+        # Fetch asynchronously, but keep the blocking vector-db write off
+        # the event loop by running it in FastAPI's threadpool.
+        docs = await loader.aload()
+        await run_in_threadpool(
+            save_docs_to_vector_db,
+            request,
+            docs,
+            collection_name,
+            overwrite=True,
+            user=user,
         )
 
         return {
             "status": True,
             "collection_name": collection_name,
             "filenames": urls,
+            "loaded_count": len(docs),
         }
     except Exception as e:
         log.exception(e)
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 4e4ba8d30..f0cdede3d 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -334,21 +334,15 @@ async def chat_web_search_handler(
 
     try:
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
-                    request,
-                    SearchForm(
-                        **{
-                            "query": searchQuery,
-                        }
-                    ),
-                    user,
-                ),
-            )
+        results = await process_web_search(
+            request,
+            SearchForm(
+                **{
+                    "query": searchQuery,
+                }
+            ),
+            user,
+        )
 
         if results:
             await event_emitter(
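
For reviewers, a minimal sketch of how the new async loader path could be exercised on its own, assuming it is driven from an event loop via asyncio.run(); the URL and settings below are placeholders, not values from this patch:

    import asyncio

    from open_webui.retrieval.web.utils import SafeWebBaseLoader

    async def main() -> None:
        # aload() awaits alazy_load(), which fetches all URLs concurrently
        # through ascrape_all() rather than blocking on each request.
        loader = SafeWebBaseLoader(
            web_paths=["https://example.com"],  # placeholder URL
            requests_per_second=2,
        )
        docs = await loader.aload()
        for doc in docs:
            print(doc.metadata.get("source"), len(doc.page_content))

    asyncio.run(main())

Inside the server this is what the reworked process_web_search endpoint does, so chat_web_search_handler can simply await it instead of spinning up a ThreadPoolExecutor.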