From e500461dc0dba17d6dcbcdc474cfb70c35e78fcf Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 17 Dec 2024 18:40:50 -0800 Subject: [PATCH] refac --- backend/open_webui/retrieval/web/utils.py | 6 +-- backend/open_webui/routers/retrieval.py | 49 ++++++++++------------- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 2df98b33c..a322bbbfc 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader): def get_web_loader( - url: Union[str, Sequence[str]], + urls: Union[str, Sequence[str]], verify_ssl: bool = True, requests_per_second: int = 2, ): # Check if the URL is valid - if not validate_url(url): + if not validate_url(urls): raise ValueError(ERROR_MESSAGES.INVALID_URL) return SafeWebBaseLoader( - url, + urls, verify_ssl=verify_ssl, requests_per_second=requests_per_second, continue_on_failure=True, diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 1898bfe49..c6a3a0cca 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1256,7 +1256,7 @@ def process_web_search( urls = [result.link for result in web_results] loader = get_web_loader( - urls=urls, + urls, verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) @@ -1429,19 +1429,23 @@ if ENV == "dev": async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): return {"result": request.app.state.EMBEDDING_FUNCTION(text)} + class BatchProcessFilesForm(BaseModel): files: List[FileModel] collection_name: str + class BatchProcessFilesResult(BaseModel): file_id: str status: str error: Optional[str] = None + class BatchProcessFilesResponse(BaseModel): results: List[BatchProcessFilesResult] errors: List[BatchProcessFilesResult] + @router.post("/process/files/batch") def process_files_batch( form_data: BatchProcessFilesForm, @@ -1459,7 +1463,7 @@ def process_files_batch( for file in form_data.files: try: text_content = file.data.get("content", "") - + docs: List[Document] = [ Document( page_content=text_content.replace("
", "\n"), @@ -1476,49 +1480,38 @@ def process_files_batch( hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) Files.update_file_data_by_id(file.id, {"content": text_content}) - + all_docs.extend(docs) - results.append(BatchProcessFilesResult( - file_id=file.id, - status="prepared" - )) + results.append(BatchProcessFilesResult(file_id=file.id, status="prepared")) except Exception as e: log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}") - errors.append(BatchProcessFilesResult( - file_id=file.id, - status="failed", - error=str(e) - )) + errors.append( + BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e)) + ) # Save all documents in one batch if all_docs: try: save_docs_to_vector_db( - docs=all_docs, - collection_name=collection_name, - add=True + docs=all_docs, collection_name=collection_name, add=True ) - + # Update all files with collection name for result in results: Files.update_file_metadata_by_id( - result.file_id, - {"collection_name": collection_name} + result.file_id, {"collection_name": collection_name} ) result.status = "completed" except Exception as e: - log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}") + log.error( + f"process_files_batch: Error saving documents to vector DB: {str(e)}" + ) for result in results: result.status = "failed" - errors.append(BatchProcessFilesResult( - file_id=result.file_id, - error=str(e) - )) - - return BatchProcessFilesResponse( - results=results, - errors=errors - ) + errors.append( + BatchProcessFilesResult(file_id=result.file_id, error=str(e)) + ) + return BatchProcessFilesResponse(results=results, errors=errors)