From 3f5f410453709d377b278037fac9553249020845 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 27 Jun 2024 11:29:59 -0700 Subject: [PATCH] refac --- backend/apps/images/main.py | 8 ++--- backend/apps/openai/main.py | 4 +-- backend/apps/rag/main.py | 20 +++++------ backend/apps/webui/routers/chats.py | 44 ++++++++++++------------- backend/apps/webui/routers/configs.py | 4 +-- backend/apps/webui/routers/documents.py | 8 ++--- backend/apps/webui/routers/prompts.py | 6 ++-- 7 files changed, 47 insertions(+), 47 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 6ec64d280..8f1a08e04 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -16,7 +16,7 @@ from faster_whisper import WhisperModel from constants import ERROR_MESSAGES from utils.utils import ( - get_current_user, + get_verified_user, get_admin_user, ) @@ -258,7 +258,7 @@ async def update_image_size( @app.get("/models") -def get_models(user=Depends(get_current_user)): +def get_models(user=Depends(get_verified_user)): try: if app.state.config.ENGINE == "openai": return [ @@ -347,7 +347,7 @@ def set_model_handler(model: str): @app.post("/models/default/update") def update_default_model( form_data: UpdateModelForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): return set_model_handler(form_data.model) @@ -424,7 +424,7 @@ def save_url_image(url): @app.post("/generations") def generate_image( form_data: GenerateImageForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 302dd8d98..31dd48741 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -16,7 +16,7 @@ from apps.webui.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( decode_token, - get_current_user, + get_verified_user, get_verified_user, get_admin_user, ) @@ -296,7 +296,7 @@ async def get_all_models(raw: bool = False): @app.get("/models") @app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): +async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): if url_idx == None: models = await get_all_models() if app.state.config.ENABLE_MODEL_FILTER: diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6c5b105d8..7c6974535 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -85,7 +85,7 @@ from utils.misc import ( sanitize_filename, extract_folders_after_data_docs, ) -from utils.utils import get_current_user, get_admin_user +from utils.utils import get_verified_user, get_admin_user from config import ( AppConfig, @@ -529,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ @app.get("/template") -async def get_rag_template(user=Depends(get_current_user)): +async def get_rag_template(user=Depends(get_verified_user)): return { "status": True, "template": app.state.config.RAG_TEMPLATE, @@ -586,7 +586,7 @@ class QueryDocForm(BaseModel): @app.post("/query/doc") def query_doc_handler( form_data: QueryDocForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): try: if app.state.config.ENABLE_RAG_HYBRID_SEARCH: @@ -626,7 +626,7 @@ class QueryCollectionsForm(BaseModel): @app.post("/query/collection") def query_collection_handler( form_data: QueryCollectionsForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): try: if app.state.config.ENABLE_RAG_HYBRID_SEARCH: @@ -657,7 +657,7 @@ def query_collection_handler( @app.post("/youtube") -def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): +def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): try: loader = YoutubeLoader.from_youtube_url( form_data.url, @@ -686,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): @app.post("/web") -def store_web(form_data: UrlForm, user=Depends(get_current_user)): +def store_web(form_data: UrlForm, user=Depends(get_verified_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = get_web_loader( @@ -864,7 +864,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: @app.post("/web/search") -def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): +def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): try: logging.info( f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" @@ -1084,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): def store_doc( collection_name: Optional[str] = Form(None), file: UploadFile = File(...), - user=Depends(get_current_user), + user=Depends(get_verified_user), ): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" @@ -1145,7 +1145,7 @@ class ProcessDocForm(BaseModel): @app.post("/process/doc") def process_doc( form_data: ProcessDocForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): try: file = Files.get_file_by_id(form_data.file_id) @@ -1200,7 +1200,7 @@ class TextRAGForm(BaseModel): @app.post("/text") def store_text( form_data: TextRAGForm, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): collection_name = form_data.collection_name diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 9d1cceaa1..c4d6575c2 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -1,7 +1,7 @@ from fastapi import Depends, Request, HTTPException, status from datetime import datetime, timedelta from typing import List, Union, Optional -from utils.utils import get_current_user, get_admin_user +from utils.utils import get_verified_user, get_admin_user from fastapi import APIRouter from pydantic import BaseModel import json @@ -43,7 +43,7 @@ router = APIRouter() @router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse]) async def get_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 + user=Depends(get_verified_user), skip: int = 0, limit: int = 50 ): return Chats.get_chat_list_by_user_id(user.id, skip, limit) @@ -54,7 +54,7 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) -async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): +async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): if ( user.role == "user" @@ -89,7 +89,7 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): +async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): try: chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -106,7 +106,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): @router.get("/all", response_model=List[ChatResponse]) -async def get_user_chats(user=Depends(get_current_user)): +async def get_user_chats(user=Depends(get_verified_user)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) for chat in Chats.get_chats_by_user_id(user.id) @@ -119,7 +119,7 @@ async def get_user_chats(user=Depends(get_current_user)): @router.get("/all/archived", response_model=List[ChatResponse]) -async def get_user_chats(user=Depends(get_current_user)): +async def get_user_chats(user=Depends(get_verified_user)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) for chat in Chats.get_archived_chats_by_user_id(user.id) @@ -151,7 +151,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): @router.get("/archived", response_model=List[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 + user=Depends(get_verified_user), skip: int = 0, limit: int = 50 ): return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) @@ -162,7 +162,7 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) -async def archive_all_chats(user=Depends(get_current_user)): +async def archive_all_chats(user=Depends(get_verified_user)): return Chats.archive_all_chats_by_user_id(user.id) @@ -172,7 +172,7 @@ async def archive_all_chats(user=Depends(get_current_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): +async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -204,7 +204,7 @@ class TagNameForm(BaseModel): @router.post("/tags", response_model=List[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_current_user) + form_data: TagNameForm, user=Depends(get_verified_user) ): print(form_data) @@ -229,7 +229,7 @@ async def get_user_chat_list_by_tag_name( @router.get("/tags/all", response_model=List[TagModel]) -async def get_all_tags(user=Depends(get_current_user)): +async def get_all_tags(user=Depends(get_verified_user)): try: tags = Tags.get_tags_by_user_id(user.id) return tags @@ -246,7 +246,7 @@ async def get_all_tags(user=Depends(get_current_user)): @router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id(id: str, user=Depends(get_current_user)): +async def get_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: @@ -264,7 +264,7 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}", response_model=Optional[ChatResponse]) async def update_chat_by_id( - id: str, form_data: ChatForm, user=Depends(get_current_user) + id: str, form_data: ChatForm, user=Depends(get_verified_user) ): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: @@ -285,7 +285,7 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): +async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): if user.role == "admin": result = Chats.delete_chat_by_id(id) @@ -307,7 +307,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ @router.get("/{id}/clone", response_model=Optional[ChatResponse]) -async def clone_chat_by_id(id: str, user=Depends(get_current_user)): +async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: @@ -333,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id(id: str, user=Depends(get_current_user)): +async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) @@ -350,7 +350,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(id: str, user=Depends(get_current_user)): +async def share_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if chat.share_id: @@ -382,7 +382,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)): @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): +async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if not chat.share_id: @@ -405,7 +405,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/tags", response_model=List[TagModel]) -async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): +async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) if tags != None: @@ -423,7 +423,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) async def add_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) + id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) ): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) @@ -450,7 +450,7 @@ async def add_chat_tag_by_id( @router.delete("/{id}/tags", response_model=Optional[bool]) async def delete_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) + id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) ): result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( form_data.tag_name, id, user.id @@ -470,7 +470,7 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): +async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) if result: diff --git a/backend/apps/webui/routers/configs.py b/backend/apps/webui/routers/configs.py index c127e721b..39e435013 100644 --- a/backend/apps/webui/routers/configs.py +++ b/backend/apps/webui/routers/configs.py @@ -14,7 +14,7 @@ from apps.webui.models.users import Users from utils.utils import ( get_password_hash, - get_current_user, + get_verified_user, get_admin_user, create_token, ) @@ -84,6 +84,6 @@ async def set_banners( @router.get("/banners", response_model=List[BannerModel]) async def get_banners( request: Request, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): return request.app.state.config.BANNERS diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index 311455390..dc53b5246 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -14,7 +14,7 @@ from apps.webui.models.documents import ( DocumentResponse, ) -from utils.utils import get_current_user, get_admin_user +from utils.utils import get_verified_user, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -25,7 +25,7 @@ router = APIRouter() @router.get("/", response_model=List[DocumentResponse]) -async def get_documents(user=Depends(get_current_user)): +async def get_documents(user=Depends(get_verified_user)): docs = [ DocumentResponse( **{ @@ -74,7 +74,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): @router.get("/doc", response_model=Optional[DocumentResponse]) -async def get_doc_by_name(name: str, user=Depends(get_current_user)): +async def get_doc_by_name(name: str, user=Depends(get_verified_user)): doc = Documents.get_doc_by_name(name) if doc: @@ -106,7 +106,7 @@ class TagDocumentForm(BaseModel): @router.post("/doc/tags", response_model=Optional[DocumentResponse]) -async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)): doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) if doc: diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index 47d8c7012..e609a0a1b 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -8,7 +8,7 @@ import json from apps.webui.models.prompts import Prompts, PromptForm, PromptModel -from utils.utils import get_current_user, get_admin_user +from utils.utils import get_verified_user, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -19,7 +19,7 @@ router = APIRouter() @router.get("/", response_model=List[PromptModel]) -async def get_prompts(user=Depends(get_current_user)): +async def get_prompts(user=Depends(get_verified_user)): return Prompts.get_prompts() @@ -52,7 +52,7 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)) @router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command(command: str, user=Depends(get_current_user)): +async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: