From 70d2571be117583368adc2988583c89e218551b5 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 7 Jan 2024 02:46:12 -0800 Subject: [PATCH] feat: rag backend auth --- backend/apps/rag/main.py | 63 +++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index ff4bbb6ce..02caf12b7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -24,6 +24,8 @@ from typing import Optional import uuid + +from utils.utils import get_current_user from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES @@ -84,7 +86,12 @@ async def get_status(): @app.get("/query/{collection_name}") -def query_collection(collection_name: str, query: str, k: Optional[int] = 4): +def query_collection( + collection_name: str, + query: str, + k: Optional[int] = 4, + user=Depends(get_current_user), +): try: collection = CHROMA_CLIENT.get_collection( name=collection_name, @@ -101,7 +108,7 @@ def query_collection(collection_name: str, query: str, k: Optional[int] = 4): @app.post("/web") -def store_web(form_data: StoreWebForm): +def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = WebBaseLoader(form_data.url) @@ -117,7 +124,11 @@ def store_web(form_data: StoreWebForm): @app.post("/doc") -def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): +def store_doc( + collection_name: str = Form(...), + file: UploadFile = File(...), + user=Depends(get_current_user), +): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" file.filename = f"{collection_name}-{file.filename}" @@ -159,26 +170,38 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): @app.get("/reset/db") -def reset_vector_db(): - CHROMA_CLIENT.reset() +def reset_vector_db(user=Depends(get_current_user)): + if user.role == "admin": + CHROMA_CLIENT.reset() + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) @app.get("/reset") -def reset(): - folder = f"{UPLOAD_DIR}" - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) +def reset(user=Depends(get_current_user)): + if user.role == "admin": + folder = f"{UPLOAD_DIR}" + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) + CHROMA_CLIENT.reset() except Exception as e: - print("Failed to delete %s. Reason: %s" % (file_path, e)) + print(e) - try: - CHROMA_CLIENT.reset() - except Exception as e: - print(e) - - return {"status": True} + return {"status": True} + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + )