From 2428878f4225f644b285ee6c1c1d251db96b165a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 28 Sep 2024 02:29:08 +0200 Subject: [PATCH] refac --- backend/open_webui/apps/retrieval/main.py | 96 ++++------------------- 1 file changed, 16 insertions(+), 80 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 3e1ec8854..497a5685d 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -246,10 +246,10 @@ app.add_middleware( class CollectionNameForm(BaseModel): - collection_name: Optional[str] = "test" + collection_name: Optional[str] = None -class UrlForm(CollectionNameForm): +class ProcessUrlForm(CollectionNameForm): url: str @@ -636,7 +636,6 @@ def store_data_in_vector_db( chunk_overlap=app.state.config.CHUNK_OVERLAP, add_start_index=True, ) - docs = text_splitter.split_documents(data) if len(docs) > 0: @@ -715,66 +714,6 @@ def store_docs_in_vector_db( return False -@app.post("/doc") -def store_doc( - collection_name: Optional[str] = Form(None), - file: UploadFile = File(...), - user=Depends(get_verified_user), -): - # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" - - log.info(f"file.content_type: {file.content_type}") - try: - unsanitized_filename = file.filename - filename = os.path.basename(unsanitized_filename) - - file_path = f"{UPLOAD_DIR}/{filename}" - - contents = file.file.read() - with open(file_path, "wb") as f: - f.write(contents) - f.close() - - f = open(file_path, "rb") - if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() - - loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, - ) - data = loader.load(filename, file.content_type, file_path) - - try: - result = store_data_in_vector_db(data, collection_name) - - if result: - return { - "status": True, - "collection_name": collection_name, - "filename": filename, - } - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, - ) - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - class ProcessFileForm(BaseModel): file_id: str collection_name: Optional[str] = None @@ -796,11 +735,10 @@ def process_file( ) data = loader.load(file.filename, file.meta.get("content_type"), file_path) - f = open(file_path, "rb") collection_name = form_data.collection_name if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() + with open(file_path, "rb") as f: + collection_name = calculate_sha256(f)[:63] try: result = store_data_in_vector_db( @@ -813,11 +751,9 @@ def process_file( ) if result: - return { "status": True, "collection_name": collection_name, - "known_type": known_type, "filename": file.meta.get("name", file.filename), } except Exception as e: @@ -839,15 +775,15 @@ def process_file( ) -class TextRAGForm(BaseModel): +class ProcessTextForm(BaseModel): name: str content: str collection_name: Optional[str] = None -@app.post("/text") -def store_text( - form_data: TextRAGForm, +@app.post("/process/text") +def process_text( + form_data: ProcessTextForm, user=Depends(get_verified_user), ): collection_name = form_data.collection_name @@ -878,9 +814,8 @@ def process_docs_dir(user=Depends(get_admin_user)): filename = path.name file_content_type = mimetypes.guess_type(path) - f = open(path, "rb") - collection_name = calculate_sha256(f)[:63] - f.close() + with open(path, "rb") as f: + collection_name = calculate_sha256(f)[:63] loader = Loader( engine=app.state.config.CONTENT_EXTRACTION_ENGINE, @@ -933,7 +868,7 @@ def process_docs_dir(user=Depends(get_admin_user)): @app.post("/process/youtube") -def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): +def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: loader = YoutubeLoader.from_youtube_url( form_data.url, @@ -944,10 +879,11 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): data = loader.load() collection_name = form_data.collection_name - if collection_name == "": + if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] store_data_in_vector_db(data, collection_name, overwrite=True) + return { "status": True, "collection_name": collection_name, @@ -962,8 +898,7 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): @app.post("/process/web") -def process_web(form_data: UrlForm, user=Depends(get_verified_user)): - # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" +def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): try: loader = get_web_loader( form_data.url, @@ -973,10 +908,11 @@ def process_web(form_data: UrlForm, user=Depends(get_verified_user)): data = loader.load() collection_name = form_data.collection_name - if collection_name == "": + if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] store_data_in_vector_db(data, collection_name, overwrite=True) + return { "status": True, "collection_name": collection_name,