diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 3c37bb09b..6fd541f4e 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) +SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine, expire_on_commit=False +) Base = declarative_base() @@ -66,4 +68,3 @@ def get_session(): except Exception as e: db.rollback() raise e - diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index fd2934bb1..9f10e0fdd 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -126,9 +126,7 @@ class AuthsTable: else: return None - def authenticate_user( - self, email: str, password: str - ) -> Optional[UserModel]: + def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") with get_session() as db: try: @@ -144,9 +142,7 @@ class AuthsTable: except: return None - def authenticate_user_by_api_key( - self, api_key: str - ) -> Optional[UserModel]: + def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") with get_session() as db: # if no api_key, return None @@ -159,9 +155,7 @@ class AuthsTable: except: return False - def authenticate_user_by_trusted_header( - self, email: str - ) -> Optional[UserModel]: + def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") with get_session() as db: try: @@ -172,12 +166,12 @@ class AuthsTable: except: return None - def update_user_password_by_id( - self, id: str, new_password: str - ) -> bool: + def update_user_password_by_id(self, id: str, new_password: str) -> bool: with get_session() as db: try: - result = db.query(Auth).filter_by(id=id).update({"password": new_password}) + result = ( + db.query(Auth).filter_by(id=id).update({"password": new_password}) + ) return True if result == 1 else False except: return False diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d71ffd992..b0c983ade 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def insert_new_chat( - self, user_id: str, form_data: ChatForm - ) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: with get_session() as db: id = str(uuid.uuid4()) chat = ChatModel( @@ -89,7 +87,9 @@ class ChatTable: "id": id, "user_id": user_id, "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" ), "chat": json.dumps(form_data.chat), "created_at": int(time.time()), @@ -103,9 +103,7 @@ class ChatTable: db.refresh(result) return ChatModel.model_validate(result) if result else None - def update_chat_by_id( - self, id: str, chat: dict - ) -> Optional[ChatModel]: + def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: with get_session() as db: try: chat_obj = db.get(Chat, id) @@ -119,9 +117,7 @@ class ChatTable: except Exception as e: return None - def insert_shared_chat_by_chat_id( - self, chat_id: str - ) -> Optional[ChatModel]: + def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_session() as db: # Get the existing chat to share chat = db.get(Chat, chat_id) @@ -145,14 +141,14 @@ class ChatTable: db.refresh(shared_result) # Update the original chat with the share_id result = ( - db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id}) + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) ) return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id( - self, chat_id: str - ) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_session() as db: try: print("update_shared_chat_by_id") @@ -271,9 +267,7 @@ class ChatTable: except Exception as e: return None - def get_chat_by_id_and_user_id( - self, id: str, user_id: str - ) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: with get_session() as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() @@ -293,13 +287,13 @@ class ChatTable: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: with get_session() as db: all_chats = ( - db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id( - self, user_id: str - ) -> List[ChatModel]: + def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: with get_session() as db: all_chats = ( db.query(Chat) diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 6348967db..897f182be 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -106,7 +106,9 @@ class DocumentsTable: def get_docs(self) -> List[DocumentModel]: with get_session() as db: - return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index d2565db3d..b7196d604 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -39,6 +39,7 @@ class FileModel(BaseModel): model_config = ConfigDict(from_attributes=True) + #################### # Forms #################### diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 417e52329..2343c9139 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -142,9 +142,9 @@ class FunctionsTable: with get_session() as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function).filter_by( - type=type, is_active=True - ).all() + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() ] else: with get_session() as db: @@ -220,10 +220,12 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: with get_session() as db: - db.query(Function).filter_by(id=id).update({ - **updated, - "updated_at": int(time.time()), - }) + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) db.commit() return self.get_function_by_id(id) except: @@ -232,10 +234,12 @@ class FunctionsTable: def deactivate_all_functions(self) -> Optional[bool]: try: with get_session() as db: - db.query(Function).update({ - "is_active": False, - "updated_at": int(time.time()), - }) + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) db.commit() return True except: diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 7641ee5a0..86b4fa49b 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -153,9 +153,7 @@ class ModelsTable: except: return None - def update_model_by_id( - self, id: str, model: ModelForm - ) -> Optional[ModelModel]: + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: # update only the fields that are present in the model with get_session() as db: diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 2157153d8..029fd5e1b 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -83,7 +83,9 @@ class PromptsTable: def get_prompts(self) -> List[PromptModel]: with get_session() as db: - return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 5ad176c37..dfe63688e 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel): class TagTable: - def insert_new_tag( - self, name: str, user_id: str - ) -> Optional[TagModel]: + def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -201,11 +199,13 @@ class TagTable: self, tag_name: str, user_id: str ) -> int: with get_session() as db: - return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) - def delete_tag_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> bool: + def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: with get_session() as db: res = ( @@ -252,9 +252,7 @@ class TagTable: log.error(f"delete_tag: {e}") return False - def delete_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str - ) -> bool: + def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) for tag in tags: diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index bef15185b..796892927 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -165,9 +165,7 @@ class UsersTable: except: return None - def update_user_role_by_id( - self, id: str, role: str - ) -> Optional[UserModel]: + def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: with get_session() as db: try: db.query(User).filter_by(id=id).update({"role": role}) @@ -193,12 +191,12 @@ class UsersTable: except: return None - def update_user_last_active_by_id( - self, id: str - ) -> Optional[UserModel]: + def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: with get_session() as db: try: - db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) @@ -217,9 +215,7 @@ class UsersTable: except: return None - def update_user_by_id( - self, id: str, updated: dict - ) -> Optional[UserModel]: + def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: with get_session() as db: try: db.query(User).filter_by(id=id).update(updated) diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index f32b074b1..1be79d259 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -78,8 +78,7 @@ async def get_session_user( @router.post("/update/profile", response_model=UserResponse) async def update_profile( - form_data: UpdateProfileForm, - session_user=Depends(get_current_user) + form_data: UpdateProfileForm, session_user=Depends(get_current_user) ): if session_user: user = Users.update_user_by_id( @@ -101,8 +100,7 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, - session_user=Depends(get_current_user) + form_data: UpdatePasswordForm, session_user=Depends(get_current_user) ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) @@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.post("/add", response_model=SigninResponse) -async def add_user( - form_data: AddUserForm, user=Depends(get_admin_user) -): +async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): if not validate_email_format(form_data.email.lower()): raise HTTPException( @@ -316,9 +312,7 @@ async def add_user( @router.get("/admin/details") -async def get_admin_details( - request: Request, user=Depends(get_current_user) -): +async def get_admin_details(request: Request, user=Depends(get_current_user)): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 8b2b9987a..3070483f3 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -55,9 +55,7 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) -async def delete_all_user_chats( - request: Request, user=Depends(get_current_user) -): +async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): if ( user.role == "user" @@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat( - form_data: ChatForm, user=Depends(get_current_user) -): +async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): try: chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id( - share_id: str, user=Depends(get_current_user) -): +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name( ) ] - chats = Chats.get_chat_list_by_chat_ids( - chat_ids, form_data.skip, form_data.limit - ) + chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) if len(chats) == 0: Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) @@ -297,9 +289,7 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id( - request: Request, id: str, user=Depends(get_current_user) -): +async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): if user.role == "admin": result = Chats.delete_chat_by_id(id) @@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id( - id: str, user=Depends(get_current_user) -): +async def archive_chat_by_id(id: str, user=Depends(get_current_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) @@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)): @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id( - id: str, user=Depends(get_current_user) -): +async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if not chat.share_id: @@ -423,9 +409,7 @@ async def delete_shared_chat_by_id( @router.get("/{id}/tags", response_model=List[TagModel]) -async def get_chat_tags_by_id( - id: str, user=Depends(get_current_user) -): +async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) if tags != None: @@ -443,9 +427,7 @@ async def get_chat_tags_by_id( @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) async def add_chat_tag_by_id( - id: str, - form_data: ChatIdTagForm, - user=Depends(get_current_user) + id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) ): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) @@ -494,9 +476,7 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_chat_tags_by_id( - id: str, user=Depends(get_current_user) -): +async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) if result: diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index f358e033c..4e1111c07 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[DocumentResponse]) -async def create_new_doc( - form_data: DocumentForm, user=Depends(get_admin_user) -): +async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): doc = Documents.get_doc_by_name(form_data.name) if doc == None: doc = Documents.insert_new_doc(user.id, form_data) @@ -76,9 +74,7 @@ async def create_new_doc( @router.get("/doc", response_model=Optional[DocumentResponse]) -async def get_doc_by_name( - name: str, user=Depends(get_current_user) -): +async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) if doc: @@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel): @router.post("/doc/tags", response_model=Optional[DocumentResponse]) -async def tag_doc_by_name( - form_data: TagDocumentForm, user=Depends(get_current_user) -): - doc = Documents.update_doc_content_by_name( - form_data.name, {"tags": form_data.tags} - ) +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): + doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) if doc: return DocumentResponse( @@ -163,8 +155,6 @@ async def update_doc_by_name( @router.delete("/doc/delete", response_model=bool) -async def delete_doc_by_name( - name: str, user=Depends(get_admin_user) -): +async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): result = Documents.delete_doc_by_name(name) return result diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index e98d1da58..fffe0743c 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -50,10 +50,7 @@ router = APIRouter() @router.post("/") -def upload_file( - file: UploadFile = File(...), - user=Depends(get_verified_user) -): +def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index d6b2d0fcb..2c473ebe8 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)): @router.delete("/{memory_id}", response_model=bool) -async def delete_memory_by_id( - memory_id: str, user=Depends(get_verified_user) -): +async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index 3912b1028..0cbf3d366 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt( - form_data: PromptForm, user=Depends(get_admin_user) -): +async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): prompt = Prompts.get_prompt_by_command(form_data.command) if prompt == None: prompt = Prompts.insert_new_prompt(user.id, form_data) @@ -54,9 +52,7 @@ async def create_new_prompt( @router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command( - command: str, user=Depends(get_current_user) -): +async def get_prompt_by_command(command: str, user=Depends(get_current_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: @@ -95,8 +91,6 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command( - command: str, user=Depends(get_admin_user) -): +async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 82a09477d..ea9db8180 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -180,9 +180,7 @@ async def update_toolkit_by_id( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id( - request: Request, id: str, user=Depends(get_admin_user) -): +async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): result = Tools.delete_tool_by_id(id) if result: diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 8a38d5b9f..9627f0b06 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -40,9 +40,7 @@ router = APIRouter() @router.get("/", response_model=List[UserModel]) -async def get_users( - skip: int = 0, limit: int = 50, user=Depends(get_admin_user) -): +async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): return Users.get_users(skip, limit) @@ -70,9 +68,7 @@ async def update_user_permissions( @router.post("/update/role", response_model=Optional[UserModel]) -async def update_user_role( - form_data: UserRoleUpdateForm, user=Depends(get_admin_user) -): +async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): if user.id != form_data.id and form_data.id != Users.get_first_user().id: return Users.update_user_role_by_id(form_data.id, form_data.role) @@ -89,9 +85,7 @@ async def update_user_role( @router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user( - user=Depends(get_verified_user) -): +async def get_user_settings_by_session_user(user=Depends(get_verified_user)): user = Users.get_user_by_id(user.id) if user: return user.settings @@ -127,9 +121,7 @@ async def update_user_settings_by_session_user( @router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user( - user=Depends(get_verified_user) -): +async def get_user_info_by_session_user(user=Depends(get_verified_user)): user = Users.get_user_by_id(user.id) if user: return user.info @@ -154,9 +146,7 @@ async def update_user_info_by_session_user( if user.info is None: user.info = {} - user = Users.update_user_by_id( - user.id, {"info": {**user.info, **form_data}} - ) + user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) if user: return user.info else: @@ -182,9 +172,7 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) -async def get_user_by_id( - user_id: str, user=Depends(get_verified_user) -): +async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): # Check if user_id is a shared chat # If it is, get the user_id from the chat @@ -267,9 +255,7 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id( - user_id: str, user=Depends(get_admin_user) -): +async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): if user.id != user_id: result = Auths.delete_auth_by_id(user_id) diff --git a/backend/main.py b/backend/main.py index 2c4d5ecfd..f35095bf1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -175,7 +175,9 @@ https://github.com/open-webui/open-webui def run_migrations(): env = os.environ.copy() env["DATABASE_URL"] = DATABASE_URL - migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env) + migration_task = subprocess.run( + ["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env + ) if migration_task.returncode > 0: raise ValueError("Error running migrations") diff --git a/backend/migrations/versions/ba76b0bae648_init.py b/backend/migrations/versions/ba76b0bae648_init.py index b1250662f..c491ed46c 100644 --- a/backend/migrations/versions/ba76b0bae648_init.py +++ b/backend/migrations/versions/ba76b0bae648_init.py @@ -5,6 +5,7 @@ Revises: Create Date: 2024-06-24 09:09:11.636336 """ + from typing import Sequence, Union from alembic import op @@ -13,7 +14,7 @@ import apps.webui.internal.db # revision identifiers, used by Alembic. -revision: str = 'ba76b0bae648' +revision: str = "ba76b0bae648" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,141 +22,153 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('auth', - sa.Column('id', sa.String(), nullable=False), - sa.Column('email', sa.String(), nullable=True), - sa.Column('password', sa.String(), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "auth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=True), + sa.Column("password", sa.String(), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('chat', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('chat', sa.String(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('share_id', sa.String(), nullable=True), - sa.Column('archived', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('share_id') + op.create_table( + "chat", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("chat", sa.String(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("share_id", sa.String(), nullable=True), + sa.Column("archived", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("share_id"), ) - op.create_table('chatidtag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('tag_name', sa.String(), nullable=True), - sa.Column('chat_id', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "chatidtag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag_name", sa.String(), nullable=True), + sa.Column("chat_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('document', - sa.Column('collection_name', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('filename', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('collection_name'), - sa.UniqueConstraint('name') + op.create_table( + "document", + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("collection_name"), + sa.UniqueConstraint("name"), ) - op.create_table('file', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('filename', sa.String(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "file", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('function', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "function", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('memory', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "memory", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('model', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('base_model_id', sa.String(), nullable=True), - sa.Column('name', sa.String(), nullable=True), - sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "model", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("base_model_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('prompt', - sa.Column('command', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('command') + op.create_table( + "prompt", + sa.Column("command", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("command"), ) - op.create_table('tag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('data', sa.String(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("data", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('tool', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tool", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('user', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('email', sa.String(), nullable=True), - sa.Column('role', sa.String(), nullable=True), - sa.Column('profile_image_url', sa.String(), nullable=True), - sa.Column('last_active_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('api_key', sa.String(), nullable=True), - sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('api_key') + op.create_table( + "user", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("profile_image_url", sa.String(), nullable=True), + sa.Column("last_active_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("api_key"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('user') - op.drop_table('tool') - op.drop_table('tag') - op.drop_table('prompt') - op.drop_table('model') - op.drop_table('memory') - op.drop_table('function') - op.drop_table('file') - op.drop_table('document') - op.drop_table('chatidtag') - op.drop_table('chat') - op.drop_table('auth') + op.drop_table("user") + op.drop_table("tool") + op.drop_table("tag") + op.drop_table("prompt") + op.drop_table("model") + op.drop_table("memory") + op.drop_table("function") + op.drop_table("file") + op.drop_table("document") + op.drop_table("chatidtag") + op.drop_table("chat") + op.drop_table("auth") # ### end Alembic commands ### diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index 9cbf42d47..781fbfff8 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -91,6 +91,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): while retries > 0: try: from config import BACKEND_DIR + db = create_engine(database_url, pool_pre_ping=True) db = db.connect() log.info("postgres is ready!")