From 34ec10a78c951b65877cb5772b3ff05a38a2e51e Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sat, 10 May 2025 17:54:41 +0400
Subject: [PATCH] refac: web search performance

Co-Authored-By: Mabeck <64421281+mmabeck@users.noreply.github.com>
---
 backend/open_webui/routers/retrieval.py |  81 ++++++++-------
 backend/open_webui/utils/middleware.py  | 131 +++++++++---------------
 2 files changed, 95 insertions(+), 117 deletions(-)

diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 58c695eb3..b86d8968d 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -3,6 +3,8 @@ import logging
 import mimetypes
 import os
 import shutil
+import asyncio
+
 import uuid
 from datetime import datetime
@@ -188,7 +190,7 @@ class ProcessUrlForm(CollectionNameForm):
 
 
 class SearchForm(BaseModel):
-    query: str
+    queries: List[str]
 
 
 @router.get("/")
@@ -1568,16 +1570,34 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
+
+    urls = []
     try:
         logging.info(
-            f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
+            f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
         )
-        web_results = await run_in_threadpool(
-            search_web,
-            request,
-            request.app.state.config.WEB_SEARCH_ENGINE,
-            form_data.query,
-        )
+
+        search_tasks = [
+            run_in_threadpool(
+                search_web,
+                request,
+                request.app.state.config.WEB_SEARCH_ENGINE,
+                query,
+            )
+            for query in form_data.queries
+        ]
+
+        search_results = await asyncio.gather(*search_tasks)
+
+        for result in search_results:
+            if result:
+                for item in result:
+                    if item and item.link:
+                        urls.append(item.link)
+
+        urls = list(dict.fromkeys(urls))
+        log.debug(f"urls: {urls}")
+
     except Exception as e:
         log.exception(e)
@@ -1586,15 +1606,7 @@ async def process_web_search(
             detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
         )
 
-    log.debug(f"web_results: {web_results}")
-
     try:
-        urls = [result.link for result in web_results]
-
-        # Remove duplicates
-        urls = list(dict.fromkeys(urls))
-        log.debug(f"urls: {urls}")
-
         loader = get_web_loader(
             urls,
             verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
@@ -1604,7 +1616,7 @@ async def process_web_search(
         docs = await loader.aload()
         urls = [
             doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
-        ]  # only keep URLs
+        ]  # only keep the urls returned by the loader
 
         if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             return {
@@ -1621,29 +1633,28 @@ async def process_web_search(
                 "loaded_count": len(docs),
             }
         else:
-            collection_names = []
-            for doc_idx, doc in enumerate(docs):
-                if doc and doc.page_content:
-                    try:
-                        collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
-                            :63
-                        ]
-
-                        collection_names.append(collection_name)
-                        await run_in_threadpool(
-                            save_docs_to_vector_db,
-                            request,
-                            [doc],
-                            collection_name,
-                            overwrite=True,
-                            user=user,
-                        )
-                    except Exception as e:
-                        log.debug(f"error saving doc {doc_idx}: {e}")
+            # Create a single collection for all documents
+            collection_name = (
+                f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
+                    :63
+                ]
+            )
+
+            try:
+                await run_in_threadpool(
+                    save_docs_to_vector_db,
+                    request,
+                    docs,
+                    collection_name,
+                    overwrite=True,
+                    user=user,
+                )
+            except Exception as e:
+                log.debug(f"error saving docs: {e}")
 
             return {
                 "status": True,
-                "collection_names": collection_names,
+                "collection_names": [collection_name],
                "filenames": urls,
                "loaded_count": len(docs),
            }
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 55e3a0d92..442dfba76 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -353,8 +353,6 @@ async def chat_web_search_handler(
         )
         return form_data
 
-    all_results = []
-
     await event_emitter(
         {
             "type": "status",
@@ -366,106 +364,75 @@ async def chat_web_search_handler(
         }
     )
 
-    gathered_results = await asyncio.gather(
-        *(
-            process_web_search(
-                request,
-                SearchForm(**{"query": searchQuery}),
-                user=user,
-            )
-            for searchQuery in queries
-        ),
-        return_exceptions=True,
-    )
+    try:
+        results = await process_web_search(
+            request,
+            SearchForm(queries=queries),
+            user=user,
+        )
 
-    for searchQuery, results in zip(queries, gathered_results):
-        try:
-            if isinstance(results, Exception):
-                raise Exception(f"Error searching {searchQuery}: {str(results)}")
+        if results:
+            files = form_data.get("files", [])
 
-            if results:
-                all_results.append(results)
-                files = form_data.get("files", [])
+            if results.get("collection_names"):
+                for col_idx, collection_name in enumerate(
+                    results.get("collection_names")
+                ):
+                    files.append(
+                        {
+                            "collection_name": collection_name,
+                            "name": ", ".join(queries),
+                            "type": "web_search",
+                            "urls": results["filenames"],
+                        }
+                    )
+            elif results.get("docs"):
+                # Invoked when bypass embedding and retrieval is set to True
+                docs = results["docs"]
+                files.append(
+                    {
+                        "docs": docs,
+                        "name": ", ".join(queries),
+                        "type": "web_search",
+                        "urls": results["filenames"],
+                    }
+                )
 
-                if results.get("collection_names"):
-                    for col_idx, collection_name in enumerate(
-                        results.get("collection_names")
-                    ):
-                        files.append(
-                            {
-                                "collection_name": collection_name,
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": [results["filenames"][col_idx]],
-                            }
-                        )
-                elif results.get("docs"):
-                    # Invoked when bypass embedding and retrieval is set to True
-                    docs = results["docs"]
+            form_data["files"] = files
 
-                    if len(docs) == len(results["filenames"]):
-                        # the number of docs and filenames (urls) should be the same
-                        for doc_idx, doc in enumerate(docs):
-                            files.append(
-                                {
-                                    "docs": [doc],
-                                    "name": searchQuery,
-                                    "type": "web_search",
-                                    "urls": [results["filenames"][doc_idx]],
-                                }
-                            )
-                    else:
-                        # edge case when the number of docs and filenames (urls) are not the same
-                        # this should not happen, but if it does, we will just append the docs
-                        files.append(
-                            {
-                                "docs": results.get("docs", []),
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": results["filenames"],
-                            }
-                        )
-
-                form_data["files"] = files
-        except Exception as e:
-            log.exception(e)
             await event_emitter(
                 {
                     "type": "status",
                     "data": {
                         "action": "web_search",
-                        "description": 'Error searching "{{searchQuery}}"',
-                        "query": searchQuery,
+                        "description": "Searched {{count}} sites",
+                        "urls": results["filenames"],
+                        "done": True,
+                    },
+                }
+            )
+        else:
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "No search results found",
                         "done": True,
                         "error": True,
                     },
                 }
             )
 
-    if all_results:
-        urls = []
-        for results in all_results:
-            if "filenames" in results:
-                urls.extend(results["filenames"])
-
+    except Exception as e:
+        log.exception(e)
         await event_emitter(
             {
                 "type": "status",
                 "data": {
                     "action": "web_search",
-                    "description": "Searched {{count}} sites",
-                    "urls": urls,
-                    "done": True,
-                },
-            }
-        )
-    else:
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "web_search",
-                    "description": "No search results found",
+                    "description": "An error occurred while searching the web",
+                    "queries": queries,
                     "done": True,
                     "error": True,
                 },