Merge pull request #9988 from Yimi81/feat-support-async-load

feat: websearch support async docs load
Timothy Jaeryang Baek 2025-02-14 14:10:46 -08:00 committed by GitHub
commit 3e543691a4
3 changed files with 68 additions and 20 deletions


@@ -1,7 +1,9 @@
 import socket
+import aiohttp
+import asyncio
 import urllib.parse
 import validators
-from typing import Union, Sequence, Iterator
+from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -68,6 +70,31 @@ def resolve_hostname(hostname):
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
+    def _unpack_fetch_results(
+        self, results: Any, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Unpack fetch results into BeautifulSoup objects."""
+        from bs4 import BeautifulSoup
+
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+                self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
+        return final_results
+
+    async def ascrape_all(
+        self, urls: List[str], parser: Union[str, None] = None
+    ) -> List[Any]:
+        """Async fetch all urls, then return soups for all results."""
+        results = await self.fetch_all(urls)
+        return self._unpack_fetch_results(results, urls, parser=parser)
+
     def lazy_load(self) -> Iterator[Document]:
         """Lazy load text from the url(s) in web_path with error handling."""
         for path in self.web_paths:
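
The two new helpers split the async path in half: the inherited, aiohttp-based fetch_all gathers the raw responses, and _unpack_fetch_results parses each one into a BeautifulSoup object, switching to the "xml" parser for URLs ending in .xml. A minimal sketch of exercising ascrape_all on its own follows; the import path is an assumption, since the diff does not show file names.

import asyncio

# Assumed import path; the diff above does not name the file.
from open_webui.retrieval.web.utils import SafeWebBaseLoader


async def main() -> None:
    loader = SafeWebBaseLoader(["https://example.com"])
    # fetch_all() downloads the pages concurrently; _unpack_fetch_results()
    # parses each body with BeautifulSoup using the loader's default parser.
    soups = await loader.ascrape_all(loader.web_paths)
    for soup in soups:
        print(soup.title.get_text() if soup.title else "(no title)")


asyncio.run(main())
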
@@ -91,6 +118,26 @@ class SafeWebBaseLoader(WebBaseLoader):
                 # Log the error and continue with the next URL
                 log.error(f"Error loading {path}: {e}")
 
+    async def alazy_load(self) -> AsyncIterator[Document]:
+        """Async lazy load text from the url(s) in web_path."""
+        results = await self.ascrape_all(self.web_paths)
+        for path, soup in zip(self.web_paths, results):
+            text = soup.get_text(**self.bs_get_text_kwargs)
+            metadata = {"source": path}
+            if title := soup.find("title"):
+                metadata["title"] = title.get_text()
+            if description := soup.find("meta", attrs={"name": "description"}):
+                metadata["description"] = description.get(
+                    "content", "No description found."
+                )
+            if html := soup.find("html"):
+                metadata["language"] = html.get("lang", "No language found.")
+            yield Document(page_content=text, metadata=metadata)
+
+    async def aload(self) -> list[Document]:
+        """Load data into Document objects."""
+        return [document async for document in self.alazy_load()]
+
 
 def get_web_loader(
     urls: Union[str, Sequence[str]],
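
Together these give SafeWebBaseLoader an async counterpart to load(): alazy_load() fetches every page up front via ascrape_all() and then yields one Document per URL with source, title, description, and language metadata, while aload() simply collects them into a list. A minimal usage sketch under the same assumed import path:

import asyncio

# Assumed import path; the diff above does not name the file.
from open_webui.retrieval.web.utils import SafeWebBaseLoader


async def main() -> None:
    loader = SafeWebBaseLoader(["https://example.com", "https://example.org"])
    # aload() drains alazy_load(), which scrapes all URLs concurrently
    # via ascrape_all() before yielding one Document per page.
    docs = await loader.aload()
    for doc in docs:
        print(doc.metadata.get("source"), doc.metadata.get("title"))


asyncio.run(main())
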


@@ -21,6 +21,7 @@ from fastapi import (
     APIRouter,
 )
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
 import tiktoken
@@ -1308,7 +1309,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 
 
 @router.post("/process/web/search")
-def process_web_search(
+async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
     try:
@@ -1341,15 +1342,21 @@ def process_web_search(
             verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        docs = loader.load()
-        save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
+        docs = await loader.aload()
+        await run_in_threadpool(
+            save_docs_to_vector_db,
+            request,
+            docs,
+            collection_name,
+            overwrite=True,
+            user=user,
         )
 
         return {
             "status": True,
             "collection_name": collection_name,
             "filenames": urls,
+            "loaded_count": len(docs),
         }
     except Exception as e:
         log.exception(e)
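
With the endpoint now async, the loader is awaited directly, while the still-synchronous save_docs_to_vector_db is pushed onto Starlette's thread pool via run_in_threadpool so it cannot block the event loop. A minimal sketch of that pattern with a stand-in blocking function (the names below are illustrative, not from the codebase):

import asyncio
import time

from fastapi.concurrency import run_in_threadpool


def blocking_save(docs: list[str]) -> int:
    # Stand-in for a synchronous, potentially slow call such as a vector DB write.
    time.sleep(1)
    return len(docs)


async def handler() -> int:
    # The await yields control while blocking_save runs in a worker thread,
    # so the event loop keeps serving other requests in the meantime.
    return await run_in_threadpool(blocking_save, ["doc-1", "doc-2"])


print(asyncio.run(handler()))
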


@@ -334,12 +334,7 @@ async def chat_web_search_handler(
     try:
-        # Offload process_web_search to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            results = await loop.run_in_executor(
-                executor,
-                lambda: process_web_search(
+        results = await process_web_search(
             request,
             SearchForm(
                 **{
@@ -347,7 +342,6 @@ async def chat_web_search_handler(
                 }
             ),
             user,
-            ),
         )
 
         if results:
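
Because process_web_search is now a coroutine function, chat_web_search_handler can await it in place; the ThreadPoolExecutor/run_in_executor detour, and the lambda it required, is no longer needed. A small self-contained sketch of the before-and-after calling pattern, using stand-in functions rather than the real handler:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def sync_search(query: str) -> list[str]:
    # Stand-in for the old synchronous process_web_search.
    return [f"result for {query!r}"]


async def async_search(query: str) -> list[str]:
    # Stand-in for the new async process_web_search.
    await asyncio.sleep(0)
    return [f"result for {query!r}"]


async def handler() -> None:
    # Before: offload the blocking call to a thread pool.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        before = await loop.run_in_executor(executor, lambda: sync_search("open webui"))

    # After: the coroutine runs on the event loop and is awaited directly.
    after = await async_search("open webui")
    print(before, after)


asyncio.run(handler())
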