refac: web search performance

Co-Authored-By: Mabeck <64421281+mmabeck@users.noreply.github.com>
Timothy Jaeryang Baek 2025-05-10 17:54:41 +04:00
parent 015ac2f532
commit 34ec10a78c
2 changed files with 95 additions and 117 deletions

View File

@@ -3,6 +3,8 @@ import logging
 import mimetypes
 import os
 import shutil
+import asyncio
 import uuid
 from datetime import datetime
@@ -188,7 +190,7 @@ class ProcessUrlForm(CollectionNameForm):
 class SearchForm(BaseModel):
-    query: str
+    queries: List[str]
 
 
 @router.get("/")
@@ -1568,16 +1570,34 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
+    urls = []
     try:
         logging.info(
             f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
         )
-        web_results = await run_in_threadpool(
-            search_web,
-            request,
-            request.app.state.config.WEB_SEARCH_ENGINE,
-            form_data.query,
-        )
+
+        search_tasks = [
+            run_in_threadpool(
+                search_web,
+                request,
+                request.app.state.config.WEB_SEARCH_ENGINE,
+                query,
+            )
+            for query in form_data.queries
+        ]
+
+        search_results = await asyncio.gather(*search_tasks)
+
+        for result in search_results:
+            if result:
+                for item in result:
+                    if item and item.link:
+                        urls.append(item.link)
+
+        urls = list(dict.fromkeys(urls))
+        log.debug(f"urls: {urls}")
     except Exception as e:
         log.exception(e)
@@ -1586,15 +1606,7 @@ async def process_web_search(
             detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
         )
 
-    log.debug(f"web_results: {web_results}")
-
     try:
-        urls = [result.link for result in web_results]
-        # Remove duplicates
-        urls = list(dict.fromkeys(urls))
-        log.debug(f"urls: {urls}")
-
         loader = get_web_loader(
             urls,
             verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
@@ -1604,7 +1616,7 @@ async def process_web_search(
         docs = await loader.aload()
         urls = [
             doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
-        ]  # only keep URLs
+        ]  # only keep the urls returned by the loader
 
         if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             return {
@@ -1621,29 +1633,28 @@ async def process_web_search(
                 "loaded_count": len(docs),
             }
         else:
-            collection_names = []
-            for doc_idx, doc in enumerate(docs):
-                if doc and doc.page_content:
-                    try:
-                        collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
-                            :63
-                        ]
-                        collection_names.append(collection_name)
-                        await run_in_threadpool(
-                            save_docs_to_vector_db,
-                            request,
-                            [doc],
-                            collection_name,
-                            overwrite=True,
-                            user=user,
-                        )
-                    except Exception as e:
-                        log.debug(f"error saving doc {doc_idx}: {e}")
+            # Create a single collection for all documents
+            collection_name = (
+                f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
+                    :63
+                ]
+            )
+
+            try:
+                await run_in_threadpool(
+                    save_docs_to_vector_db,
+                    request,
+                    docs,
+                    collection_name,
+                    overwrite=True,
+                    user=user,
+                )
+            except Exception as e:
+                log.debug(f"error saving docs: {e}")
 
             return {
                 "status": True,
-                "collection_names": collection_names,
+                "collection_names": [collection_name],
                 "filenames": urls,
                 "loaded_count": len(docs),
             }

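The hunks above fan each blocking search_web call out to the run_in_threadpool helper (Starlette's, re-exported by FastAPI) and await all of them at once with asyncio.gather. Below is a minimal standalone sketch of that pattern, assuming Starlette is installed; blocking_search is a hypothetical stand-in for search_web and is not part of this commit.

import asyncio

from starlette.concurrency import run_in_threadpool


def blocking_search(query: str) -> list[str]:
    # hypothetical synchronous search call standing in for search_web
    return [f"https://example.com/search?q={query}"]


async def search_all(queries: list[str]) -> list[str]:
    # run each blocking call in the threadpool and await them concurrently
    tasks = [run_in_threadpool(blocking_search, query) for query in queries]
    results = await asyncio.gather(*tasks)
    # flatten, then de-duplicate while preserving order (dict.fromkeys)
    urls = [url for result in results for url in result]
    return list(dict.fromkeys(urls))


if __name__ == "__main__":
    print(asyncio.run(search_all(["foo", "bar", "foo"])))

Because run_in_threadpool returns an awaitable, asyncio.gather lets the searches run concurrently instead of one query at a time, which is where the performance gain named in the commit title comes from.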
View File

@@ -353,8 +353,6 @@ async def chat_web_search_handler(
         )
         return form_data
 
-    all_results = []
-
     await event_emitter(
         {
             "type": "status",
@@ -366,106 +364,75 @@ async def chat_web_search_handler(
         }
     )
 
-    gathered_results = await asyncio.gather(
-        *(
-            process_web_search(
-                request,
-                SearchForm(**{"query": searchQuery}),
-                user=user,
-            )
-            for searchQuery in queries
-        ),
-        return_exceptions=True,
-    )
-
-    for searchQuery, results in zip(queries, gathered_results):
-        try:
-            if isinstance(results, Exception):
-                raise Exception(f"Error searching {searchQuery}: {str(results)}")
-
-            if results:
-                all_results.append(results)
-                files = form_data.get("files", [])
-
-                if results.get("collection_names"):
-                    for col_idx, collection_name in enumerate(
-                        results.get("collection_names")
-                    ):
-                        files.append(
-                            {
-                                "collection_name": collection_name,
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": [results["filenames"][col_idx]],
-                            }
-                        )
-                elif results.get("docs"):
-                    # Invoked when bypass embedding and retrieval is set to True
-                    docs = results["docs"]
-
-                    if len(docs) == len(results["filenames"]):
-                        # the number of docs and filenames (urls) should be the same
-                        for doc_idx, doc in enumerate(docs):
-                            files.append(
-                                {
-                                    "docs": [doc],
-                                    "name": searchQuery,
-                                    "type": "web_search",
-                                    "urls": [results["filenames"][doc_idx]],
-                                }
-                            )
-                    else:
-                        # edge case when the number of docs and filenames (urls) are not the same
-                        # this should not happen, but if it does, we will just append the docs
-                        files.append(
-                            {
-                                "docs": results.get("docs", []),
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": results["filenames"],
-                            }
-                        )
-
-                form_data["files"] = files
-        except Exception as e:
-            log.exception(e)
-            await event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "action": "web_search",
-                        "description": 'Error searching "{{searchQuery}}"',
-                        "query": searchQuery,
-                        "done": True,
-                        "error": True,
-                    },
-                }
-            )
-
-    if all_results:
-        urls = []
-        for results in all_results:
-            if "filenames" in results:
-                urls.extend(results["filenames"])
-
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "web_search",
-                    "description": "Searched {{count}} sites",
-                    "urls": urls,
-                    "done": True,
-                },
-            }
-        )
-    else:
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "web_search",
-                    "description": "No search results found",
-                    "done": True,
-                    "error": True,
-                },
+    try:
+        results = await process_web_search(
+            request,
+            SearchForm(queries=queries),
+            user=user,
+        )
+
+        if results:
+            files = form_data.get("files", [])
+
+            if results.get("collection_names"):
+                for col_idx, collection_name in enumerate(
+                    results.get("collection_names")
+                ):
+                    files.append(
+                        {
+                            "collection_name": collection_name,
+                            "name": ", ".join(queries),
+                            "type": "web_search",
+                            "urls": results["filenames"],
+                        }
+                    )
+            elif results.get("docs"):
+                # Invoked when bypass embedding and retrieval is set to True
+                docs = results["docs"]
+                files.append(
+                    {
+                        "docs": docs,
+                        "name": ", ".join(queries),
+                        "type": "web_search",
+                        "urls": results["filenames"],
+                    }
+                )
+
+            form_data["files"] = files
+
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "Searched {{count}} sites",
+                        "urls": results["filenames"],
+                        "done": True,
+                    },
+                }
+            )
+        else:
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "No search results found",
+                        "done": True,
+                        "error": True,
+                    },
+                }
+            )
+    except Exception as e:
+        log.exception(e)
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": "An error occurred while searching the web",
+                    "queries": queries,
+                    "done": True,
+                    "error": True,
+                },