From 4e433d9015b2d744bc0efdc504d2a8865f0bc5e1 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 3 Jul 2024 18:18:33 +0100 Subject: [PATCH 01/24] wip: citations via __event_emitter__ --- src/lib/components/chat/Chat.svelte | 32 ++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 3d03246b7..87bd9b4de 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -132,15 +132,33 @@ console.log(data); let message = history.messages[data.message_id]; - const status = { - done: data?.data?.done ?? null, - description: data?.data?.status ?? null - }; + const type = data?.data?.type ?? null; + if (type === "status") { + const status = { + done: data?.data?.done ?? null, + description: data?.data?.status ?? null + }; - if (message.statusHistory) { - message.statusHistory.push(status); + if (message.statusHistory) { + message.statusHistory.push(status); + } else { + message.statusHistory = [status]; + } + } else if (type === "citation") { + console.log(data); + const citation = { + document: data?.data?.document ?? null, + metadata: data?.data?.metadata ?? null, + source: data?.data?.source ?? null + }; + + if (message.citations) { + message.citations.push(citation); + } else { + message.citations = [citation]; + } } else { - message.statusHistory = [status]; + console.log("Unknown message type", data); } messages = messages; From f6dcffab135bc0be3036970ab1bec6c9ed70a91d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 21:18:40 -0700 Subject: [PATCH 02/24] fix: pinned chat delete issue --- src/lib/components/layout/Sidebar.svelte | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 193fe41ff..e6cd45c1f 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -186,6 +186,7 @@ goto('/'); } await chats.set(await getChatList(localStorage.token)); + await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); } }; From 864646094e248d1ee3ed9f09e12312ec241b3217 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:32:39 -0700 Subject: [PATCH 03/24] refac --- backend/apps/webui/models/auths.py | 20 +- backend/apps/webui/models/chats.py | 322 ++++++++++++++----------- backend/apps/webui/models/documents.py | 96 ++++---- backend/apps/webui/models/files.py | 82 ++++--- backend/apps/webui/models/functions.py | 184 +++++++------- backend/apps/webui/models/memories.py | 130 +++++----- backend/apps/webui/models/models.py | 53 ++-- backend/apps/webui/models/prompts.py | 54 +++-- backend/apps/webui/models/tags.py | 212 ++++++++-------- backend/apps/webui/models/tools.py | 89 ++++--- backend/apps/webui/models/users.py | 163 +++++++------ 11 files changed, 789 insertions(+), 616 deletions(-) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 560d9a686..48c8b543a 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -110,14 +110,14 @@ class AuthsTable: **{"id": id, "email": email, "password": password, "active": True} ) result = Auth(**auth.model_dump()) - Session.add(result) + db.add(result) user = Users.insert_new_user( id, name, email, profile_image_url, role, oauth_sub ) - Session.commit() - Session.refresh(result) + db.commit() + db.refresh(result) if result and user: return user @@ -127,7 +127,7 @@ class AuthsTable: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = Session.query(Auth).filter_by(email=email, active=True).first() + auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: if verify_password(password, auth.password): user = Users.get_user_by_id(auth.id) @@ -154,7 +154,7 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = Session.query(Auth).filter(email=email, active=True).first() + auth = db.query(Auth).filter(email=email, active=True).first() if auth: user = Users.get_user_by_id(auth.id) return user @@ -163,16 +163,14 @@ class AuthsTable: def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = ( - Session.query(Auth).filter_by(id=id).update({"password": new_password}) - ) + result = db.query(Auth).filter_by(id=id).update({"password": new_password}) return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = Session.query(Auth).filter_by(id=id).update({"email": email}) + result = db.query(Auth).filter_by(id=id).update({"email": email}) return True if result == 1 else False except: return False @@ -183,7 +181,7 @@ class AuthsTable: result = Users.delete_user_by_id(id) if result: - Session.query(Auth).filter_by(id=id).delete() + db.query(Auth).filter_by(id=id).delete() return True else: diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d6829ee7b..8d2e6b104 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import time from sqlalchemy import Column, String, BigInteger, Boolean, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db #################### @@ -79,87 +79,99 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + with get_db() as db: - result = Chat(**chat.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - return ChatModel.model_validate(result) if result else None + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - chat_obj = Session.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - Session.commit() - Session.refresh(chat_obj) + with get_db() as db: - return ChatModel.model_validate(chat_obj) + chat_obj = db.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + db.commit() + db.refresh(chat_obj) + + return ChatModel.model_validate(chat_obj) except Exception as e: return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - # Get the existing chat to share - chat = Session.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - Session.add(shared_result) - Session.commit() - Session.refresh(shared_result) - # Update the original chat with the share_id - result = ( - Session.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + with get_db() as db: - return shared_chat if (shared_result and result) else None + # Get the existing chat to share + chat = db.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) + # Update the original chat with the share_id + result = ( + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) + + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: - print("update_shared_chat_by_id") - chat = Session.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - Session.commit() - Session.refresh(chat) + with get_db() as db: - return self.get_chat_by_id(chat.share_id) + print("update_shared_chat_by_id") + chat = db.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + db.commit() + db.refresh(chat) + + return self.get_chat_by_id(chat.share_id) except: return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + return True except: return False @@ -167,42 +179,50 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.share_id = share_id - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.share_id = share_id + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.archived = not chat.archived - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.archived = not chat.archived + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + return True except: return False def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, @@ -211,110 +231,136 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - query = Session.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(share_id=id).first() + with get_db() as db: - if chat: - return self.get_chat_by_id(id) - else: - return None + chat = db.query(Chat).filter_by(share_id=id).first() + + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - Session.query(Chat).filter_by(id=id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - Session.query(Chat).filter_by(id=id, user_id=user_id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id, user_id=user_id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - self.delete_shared_chats_by_user_id(user_id) - Session.query(Chat).filter_by(user_id=user_id).delete() - return True + with get_db() as db: + + self.delete_shared_chats_by_user_id(user_id) + + db.query(Chat).filter_by(user_id=user_id).delete() + return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + with get_db() as db: - return True + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + + return True except: return False diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 1b69d44a5..16145c4ac 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -74,51 +74,59 @@ class DocumentsTable: def insert_new_doc( self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) + with get_db() as db: - try: - result = Document(**document.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: + document = DocumentModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "timestamp": int(time.time()), + } + ) + + try: + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None + except: return None - except: - return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - document = Session.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + with get_db() as db: + + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - return [ - DocumentModel.model_validate(doc) for doc in Session.query(Document).all() - ] + with get_db() as db: + + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - Session.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(form_data.name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -131,22 +139,26 @@ class DocumentsTable: doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated} - Session.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - Session.query(Document).filter_by(name=name).delete() - return True + with get_db() as db: + + db.query(Document).filter_by(name=name).delete() + return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index ce904215d..58058e907 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db import json @@ -61,50 +61,62 @@ class FileForm(BaseModel): class FilesTable: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - file = FileModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - } - ) + with get_db() as db: - try: - result = File(**file.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FileModel.model_validate(result) - else: + file = FileModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + } + ) + + try: + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_file_by_id(self, id: str) -> Optional[FileModel]: - try: - file = Session.get(File, id) - return FileModel.model_validate(file) - except: - return None + with get_db() as db: + + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except: + return None def get_files(self) -> List[FileModel]: - return [FileModel.model_validate(file) for file in Session.query(File).all()] + with get_db() as db: + + return [FileModel.model_validate(file) for file in db.query(File).all()] def delete_file_by_id(self, id: str) -> bool: - try: - Session.query(File).filter_by(id=id).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).filter_by(id=id).delete() + return True + except: + return False def delete_all_files(self) -> bool: - try: - Session.query(File).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).delete() + return True + except: + return False Files = FilesTable() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 64ed4f3cc..5718833d3 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db from apps.webui.models.users import Users import json @@ -91,6 +91,7 @@ class FunctionsTable: def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: + function = FunctionModel( **{ **form_data.model_dump(), @@ -102,85 +103,99 @@ class FunctionsTable: ) try: - result = Function(**function.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + with get_db() as db: + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - function = Session.get(Function, id) - return FunctionModel.model_validate(function) + with get_db() as db: + + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(is_active=True).all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(is_active=True).all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(type=type).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(type=type).all() + ] def get_global_filter_functions(self) -> List[FunctionModel]: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() - ] + with get_db() as db: + + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type="filter", is_active=True, is_global=True) + .all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: - try: - function = Session.get(Function, id) - return function.valves if function.valves else {} - except Exception as e: - print(f"An error occurred: {e}") - return None + with get_db() as db: + + try: + function = db.get(Function, id) + return function.valves if function.valves else {} + except Exception as e: + print(f"An error occurred: {e}") + return None def update_function_valves_by_id( self, id: str, valves: dict ) -> Optional[FunctionValves]: - try: - function = Session.get(Function, id) - function.valves = valves - function.updated_at = int(time.time()) - Session.commit() - Session.refresh(function) - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + function = db.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + db.commit() + db.refresh(function) + return self.get_function_by_id(id) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -199,6 +214,7 @@ class FunctionsTable: def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -220,37 +236,43 @@ class FunctionsTable: return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - try: - Session.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - Session.commit() - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_function_by_id(id) + except: + return None def deactivate_all_functions(self) -> Optional[bool]: - try: - Session.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - Session.commit() - return True - except: - return None + with get_db() as db: + + try: + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + db.commit() + return True + except: + return None def delete_function_by_id(self, id: str) -> bool: - try: - Session.query(Function).filter_by(id=id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).delete() + return True + except: + return False Functions = FunctionsTable() diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 1f03318fd..662bbedfe 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,7 +3,7 @@ from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import time import uuid @@ -45,82 +45,98 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - id = str(uuid.uuid4()) - memory = MemoryModel( - **{ - "id": id, - "user_id": user_id, - "content": content, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) - result = Memory(**memory.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + with get_db() as db: + id = str(uuid.uuid4()) + + memory = MemoryModel( + **{ + "id": id, + "user_id": user_id, + "content": content, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + result = Memory(**memory.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, id: str, content: str, ) -> Optional[MemoryModel]: - try: - Session.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_memory_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + db.commit() + return self.get_memory_by_id(id) + except: + return None def get_memories(self) -> List[MemoryModel]: - try: - memories = Session.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: - try: - memories = Session.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - try: - memory = Session.get(Memory, id) - return MemoryModel.model_validate(memory) - except: - return None + with get_db() as db: + + try: + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) + except: + return None def delete_memory_by_id(self, id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id).delete() - return True + with get_db() as db: - except: - return False + try: + db.query(Memory).filter_by(id=id).delete() + return True + + except: + return False def delete_memories_by_user_id(self, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(user_id=user_id).delete() + return True + except: + return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id, user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id, user_id=user_id).delete() + return True + except: + return False Memories = MemoriesTable() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 6543edefc..c95c36c7d 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -126,39 +126,46 @@ class ModelsTable: } ) try: - result = Model(**model.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Model(**model.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - return [ - ModelModel.model_validate(model) for model in Session.query(Model).all() - ] + with get_db() as db: + + return [ModelModel.model_validate(model) for model in db.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - model = Session.get(Model, id) - return ModelModel.model_validate(model) + with get_db() as db: + + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: - # update only the fields that are present in the model - model = Session.query(Model).get(id) - model.update(**model.model_dump()) - Session.commit() - Session.refresh(model) - return ModelModel.model_validate(model) + with get_db() as db: + + # update only the fields that are present in the model + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -166,8 +173,10 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - Session.query(Model).filter_by(id=id).delete() - return True + with get_db() as db: + + db.query(Model).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index ab8cc04ce..2af2ce22c 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -4,7 +4,7 @@ import time from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -60,46 +60,56 @@ class PromptsTable: ) try: - result = Prompt(**prompt.dict()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return PromptModel.model_validate(result) + else: + return None except Exception as e: return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) except: return None def get_prompts(self) -> List[PromptModel]: - return [ - PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() - ] + with get_db() as db: + + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - Session.commit() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + db.commit() + return PromptModel.model_validate(prompt) except: return None def delete_prompt_by_command(self, command: str) -> bool: try: - Session.query(Prompt).filter_by(command=command).delete() - return True + with get_db() as db: + + db.query(Prompt).filter_by(command=command).delete() + return True except: return False diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7b0df6b6b..bbbc95ed2 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -79,26 +79,29 @@ class ChatTagsResponse(BaseModel): class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - id = str(uuid.uuid4()) - tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) - try: - result = Tag(**tag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return TagModel.model_validate(result) - else: + with get_db() as db: + + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + try: + result = Tag(**tag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None + except Exception as e: return None - except Exception as e: - return None def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = Session.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + with get_db() as db: + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -120,98 +123,109 @@ class TagTable: } ) try: - result = ChatIdTag(**chatIdTag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + with get_db() as db: + result = ChatIdTag(**chatIdTag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] + + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - return ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + with get_db() as db: + + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + return True except Exception as e: log.error(f"delete_tag: {e}") return False @@ -220,20 +234,24 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + + return True except Exception as e: log.error(f"delete_tag: {e}") return False diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index f5df10637..dc0fe01c5 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -4,7 +4,7 @@ import time import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users import json @@ -83,54 +83,63 @@ class ToolsTable: def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ToolModel.model_validate(result) - else: + with get_db() as db: + + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool(**tool.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - return ToolModel.model_validate(tool) + with get_db() as db: + + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - tool = Session.get(Tool, id) - return tool.valves if tool.valves else {} + with get_db() as db: + + tool = db.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - Session.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_tool_by_id(id) + with get_db() as db: + + db.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None @@ -177,19 +186,21 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - tool.update(**updated) - tool.updated_at = int(time.time()) - Session.commit() - Session.refresh(tool) - return ToolModel.model_validate(tool) + with get_db() as db: + tool = db.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + db.commit() + db.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - Session.query(Tool).filter_by(id=id).delete() - return True + with get_db() as db: + db.query(Tool).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 9e1e25ac6..8e3b57bba 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, Session, get_db from apps.webui.models.chats import Chats #################### @@ -88,81 +88,92 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return user - else: - return None + with get_db() as db: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except Exception as e: return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) except: return None def get_user_by_email(self, email: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) except: return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) except: return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - users = ( - Session.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + with get_db() as db: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - return Session.query(User).count() + with get_db() as db: + return db.query(User).count() def get_first_user(self) -> UserModel: try: - user = Session.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) except: return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"role": role}) - Session.commit() - - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + db.query(User).filter_by(id=id).update({"role": role}) + db.commit() + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -170,25 +181,28 @@ class UsersTable: self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) - Session.commit() + with get_db() as db: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -196,21 +210,23 @@ class UsersTable: self, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + with get_db() as db: + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update(updated) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update(updated) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) except Exception as e: return None @@ -220,9 +236,10 @@ class UsersTable: result = Chats.delete_chats_by_user_id(id) if result: - # Delete User - Session.query(User).filter_by(id=id).delete() - Session.commit() + with get_db() as db: + # Delete User + db.query(User).filter_by(id=id).delete() + db.commit() return True else: @@ -232,16 +249,18 @@ class UsersTable: def update_user_api_key_by_id(self, id: str, api_key: str) -> str: try: - result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) - Session.commit() - return True if result == 1 else False + with get_db() as db: + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() + return True if result == 1 else False except: return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: - user = Session.query(User).filter_by(id=id).first() - return user.api_key + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return user.api_key except Exception as e: return None From 37a5d2c06b78098ed70f52e9fefdc824ad96d531 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:32:46 -0700 Subject: [PATCH 04/24] Update db.py --- backend/apps/webui/internal/db.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 320ab3e07..bfdc52c11 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -53,8 +53,19 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + + SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) Base = declarative_base() Session = scoped_session(SessionLocal) + + +# Dependency +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() From 8fe2a7bb75e222f49f177437a0e1b5279b23a37e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:39:16 -0700 Subject: [PATCH 05/24] fix --- backend/apps/webui/internal/db.py | 8 +++++++- backend/apps/webui/models/tools.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index bfdc52c11..333e215ea 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -62,10 +62,16 @@ Base = declarative_base() Session = scoped_session(SessionLocal) +from contextlib import contextmanager + + # Dependency -def get_db(): +def get_session(): db = SessionLocal() try: yield db finally: db.close() + + +get_db = contextmanager(get_session) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index dc0fe01c5..4cc06826a 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -119,7 +119,8 @@ class ToolsTable: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + with get_db() as db: + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: From 8b13755d5634d76840077c8bdcac6def93d86a70 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:25:45 -0700 Subject: [PATCH 06/24] Update auths.py --- backend/apps/webui/models/auths.py | 89 +++++++++++++++++------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 48c8b543a..7698359f9 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -102,40 +102,44 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - log.info("insert_new_auth") + with get_db() as db: - id = str(uuid.uuid4()) + log.info("insert_new_auth") - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - db.add(result) + id = str(uuid.uuid4()) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub - ) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + db.add(result) - db.commit() - db.refresh(result) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) - if result and user: - return user - else: - return None + db.commit() + db.refresh(result) + + if result and user: + return user + else: + return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + + auth = db.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user + else: + return None else: return None - else: - return None except: return None @@ -154,38 +158,47 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = db.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + auth = db.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user except: return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = db.query(Auth).filter_by(id=id).update({"password": new_password}) - return True if result == 1 else False + with get_db() as db: + + result = ( + db.query(Auth).filter_by(id=id).update({"password": new_password}) + ) + return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False + with get_db() as db: + + result = db.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False except: return False def delete_auth_by_id(self, id: str) -> bool: try: - # Delete User - result = Users.delete_user_by_id(id) + with get_db() as db: - if result: - db.query(Auth).filter_by(id=id).delete() + # Delete User + result = Users.delete_user_by_id(id) - return True - else: - return False + if result: + db.query(Auth).filter_by(id=id).delete() + + return True + else: + return False except: return False From 9a6cbafdef7a1a44c7e3ad914996204d07c4a77e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:37:05 -0700 Subject: [PATCH 07/24] fix: user valves --- backend/apps/webui/models/functions.py | 4 ++-- backend/apps/webui/models/tools.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 677f022f6..33a9d1297 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -185,7 +185,7 @@ class FunctionsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: @@ -203,7 +203,7 @@ class FunctionsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 950972c2d..e7830e214 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -141,7 +141,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: @@ -159,7 +159,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: From 740b6f5c17533350ae002f62e0097d8730350c04 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:42:18 -0700 Subject: [PATCH 08/24] fix: pull model --- src/lib/components/admin/Settings/Models.svelte | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 57d0be135..b95829826 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -158,12 +158,14 @@ return; } - const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch( - (error) => { - toast.error(error); - return null; - } - ); + const [res, controller] = await pullModel( + localStorage.token, + sanitizedModelTag, + selectedOllamaUrlIdx + ).catch((error) => { + toast.error(error); + return null; + }); if (res) { const reader = res.body From 05277556005230847f552b55c2d896ecd57fe281 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 4 Jul 2024 12:21:09 +0100 Subject: [PATCH 09/24] use data field --- src/lib/components/chat/Chat.svelte | 31 ++++++++++++----------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 87bd9b4de..de64d2681 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -133,29 +133,24 @@ let message = history.messages[data.message_id]; const type = data?.data?.type ?? null; - if (type === "status") { - const status = { - done: data?.data?.done ?? null, - description: data?.data?.status ?? null - }; - + const payload = data?.data?.data ?? null; + if (!type || !payload) { + console.log("Data and type fields must be provided.", data); + return; + } + const status_keys = ["done", "description"]; + const citation_keys = ["document", "metadata", "source"]; + if (type === "status" && status_keys.every(key => key in payload)) { if (message.statusHistory) { - message.statusHistory.push(status); + message.statusHistory.push(payload); } else { - message.statusHistory = [status]; + message.statusHistory = [payload]; } - } else if (type === "citation") { - console.log(data); - const citation = { - document: data?.data?.document ?? null, - metadata: data?.data?.metadata ?? null, - source: data?.data?.source ?? null - }; - + } else if (type === "citation" && citation_keys.every(key => key in payload)) { if (message.citations) { - message.citations.push(citation); + message.citations.push(payload); } else { - message.citations = [citation]; + message.citations = [payload]; } } else { console.log("Unknown message type", data); From d20601dc475034d51a7617ba9ceedb84fdbacabf Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 13:53:28 +0000 Subject: [PATCH 10/24] feat: Add custom Collapsible component for collapsible content --- src/lib/components/common/Collapsible.svelte | 37 ++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/lib/components/common/Collapsible.svelte diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte new file mode 100644 index 000000000..c87ffe8ba --- /dev/null +++ b/src/lib/components/common/Collapsible.svelte @@ -0,0 +1,37 @@ + + + + +
+ +
+ +
+
\ No newline at end of file From 2389c36a70d55ee6da4164b9e085a322e488a194 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 13:55:37 +0000 Subject: [PATCH 11/24] refactor: Update WebSearchResults.svelte to use new CollapsibleComponent --- .../ResponseMessage/WebSearchResults.svelte | 146 +++++++++--------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte index 528108036..25001730e 100644 --- a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte @@ -2,17 +2,18 @@ import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte'; - import { Collapsible } from 'bits-ui'; - import { slide } from 'svelte/transition'; + import Collapsible from '$lib/components/common/Collapsible.svelte'; + export let status = { urls: [], query: '' }; let state = false; - - +
+
@@ -22,76 +23,75 @@ {/if}
- - - - {#if status?.query} - -
- - -
- {status.query} + - -
- - - - -
-
- {/if} - - {#each status.urls as url, urlIdx} - -
- {url} -
- -
+ + + +
+
+ {/if} + + {#each status.urls as url, urlIdx} + - - + {url} +
+ +
- - -
-
- {/each} - - + + + + +
+ + {/each} +
+ + \ No newline at end of file From d5c0876a0b180cfc413a7dfb55ae4fe34f2f5d52 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:02:26 +0000 Subject: [PATCH 12/24] refactor: fixed new Collapsible Component to allow passed in classes chore: format --- .../ResponseMessage/WebSearchResults.svelte | 160 +++++++++--------- src/lib/components/common/Collapsible.svelte | 31 ++-- 2 files changed, 92 insertions(+), 99 deletions(-) diff --git a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte index 25001730e..4523c8482 100644 --- a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte @@ -4,94 +4,88 @@ import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte'; import Collapsible from '$lib/components/common/Collapsible.svelte'; - export let status = { urls: [], query: '' }; let state = false; -
- -
- + +
+ + + {#if state} + + {:else} + + {/if} +
+
+ {#if status?.query} + +
+ - {#if state} - - {:else} - - {/if} -
-
\ No newline at end of file + + +
+ + {/if} + + {#each status.urls as url, urlIdx} + +
+ {url} +
+ +
+ + + + +
+
+ {/each} +
+
diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index c87ffe8ba..b681143a6 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -2,11 +2,11 @@ import { afterUpdate } from 'svelte'; export let open = false; - + export let className = ''; // Manage the max-height of the collapsible content for snappy transitions let contentElement: HTMLElement; - let maxHeight = '0px'; // Initial max-height + let maxHeight = '0px'; // Initial max-height // After any state update, adjust the max-height for the transition afterUpdate(() => { if (open) { @@ -15,23 +15,22 @@ } else { maxHeight = '0px'; } - }); - + }); +
+ +
+ +
+
+ - -
- -
- -
-
\ No newline at end of file From db58bb5f0f51521fa5c52e1b4e8107e6275904ad Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:15:16 +0000 Subject: [PATCH 13/24] refactor: Removed dependency --- src/lib/components/common/Collapsible.svelte | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index b681143a6..0a140d9dd 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -1,21 +1,19 @@
From 78ba18a680f9cce4c895279282ecf60fc581f382 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:55:48 +0000 Subject: [PATCH 14/24] refactor: Update Collapsible component to include dynamic margin for open state --- src/lib/components/common/Collapsible.svelte | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index 0a140d9dd..14e5785a4 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -20,7 +20,7 @@ -
+
@@ -28,7 +28,7 @@ From f611533764ece12128d5a3daaa4a0ee53e0e3b64 Mon Sep 17 00:00:00 2001 From: Karl Lee <61072264+KarlLee830@users.noreply.github.com> Date: Thu, 4 Jul 2024 22:57:32 +0800 Subject: [PATCH 15/24] i18n: Update Chinese translation --- src/lib/i18n/locales/zh-CN/translation.json | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/lib/i18n/locales/zh-CN/translation.json b/src/lib/i18n/locales/zh-CN/translation.json index d5887e2ff..366b717f0 100644 --- a/src/lib/i18n/locales/zh-CN/translation.json +++ b/src/lib/i18n/locales/zh-CN/translation.json @@ -126,7 +126,7 @@ "Connections": "外部连接", "Contact Admin for WebUI Access": "请联系管理员以获取访问权限", "Content": "内容", - "Content Extraction": "", + "Content Extraction": "内容提取", "Context Length": "上下文长度", "Continue Response": "继续生成", "Continue with {{provider}}": "使用 {{provider}} 继续", @@ -213,7 +213,7 @@ "Enable Community Sharing": "启用分享至社区", "Enable New Sign Ups": "允许新用户注册", "Enable Web Search": "启用网络搜索", - "Engine": "", + "Engine": "引擎", "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "确保您的 CSV 文件按以下顺序包含 4 列: 姓名、电子邮箱、密码、角色。", "Enter {{role}} message here": "在此处输入 {{role}} 信息", "Enter a detail about yourself for your LLMs to recall": "输入一个关于你自己的详细信息,方便你的大语言模型记住这些内容", @@ -235,7 +235,7 @@ "Enter Serpstack API Key": "输入 Serpstack API 密钥", "Enter stop sequence": "输入停止序列 (Stop Sequence)", "Enter Tavily API Key": "输入 Tavily API 密钥", - "Enter Tika Server URL": "", + "Enter Tika Server URL": "输入 Tika 服务器地址", "Enter Top K": "输入 Top K", "Enter URL (e.g. http://127.0.0.1:7860/)": "输入地址 (例如:http://127.0.0.1:7860/)", "Enter URL (e.g. http://localhost:11434)": "输入地址 (例如:http://localhost:11434)", @@ -412,7 +412,7 @@ "Open": "打开", "Open AI (Dall-E)": "Open AI (Dall-E)", "Open new chat": "打开新对话", - "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "", + "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "当前 Open WebUI 版本 (v{{OPEN_WEBUI_VERSION}}) 低于所需的版本 (v{{REQUIRED_VERSION}})", "OpenAI": "OpenAI", "OpenAI API": "OpenAI API", "OpenAI API Config": "OpenAI API 配置", @@ -428,8 +428,8 @@ "Permission denied when accessing microphone": "申请麦克风权限被拒绝", "Permission denied when accessing microphone: {{error}}": "申请麦克风权限被拒绝:{{error}}", "Personalization": "个性化", - "Pin": "", - "Pinned": "", + "Pin": "置顶", + "Pinned": "已置顶", "Pipeline deleted successfully": "Pipeline 删除成功", "Pipeline downloaded successfully": "Pipeline 下载成功", "Pipelines": "Pipeline", @@ -578,8 +578,8 @@ "This setting does not sync across browsers or devices.": "此设置不会在浏览器或设备之间同步。", "This will delete": "这将删除", "Thorough explanation": "解释较为详细", - "Tika": "", - "Tika Server URL required.": "", + "Tika": "Tika", + "Tika Server URL required.": "请输入 Tika 服务器地址。", "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "提示:在每次替换后,在对话输入中按 Tab 键可以连续更新多个变量。", "Title": "标题", "Title (e.g. Tell me a fun fact)": "标题(例如 给我讲一个有趣的事实)", @@ -614,7 +614,7 @@ "Uh-oh! There was an issue connecting to {{provider}}.": "糟糕!连接到 {{provider}} 时出现问题。", "UI": "界面", "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "未知文件类型“{{file_type}}”,将无视继续上传文件。", - "Unpin": "", + "Unpin": "取消置顶", "Update": "更新", "Update and Copy Link": "更新和复制链接", "Update password": "更新密码", From ca3f8e6cb52231a21f7c157fd1e38504665b1793 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 15:18:21 +0000 Subject: [PATCH 16/24] chore: format --- src/lib/components/common/Collapsible.svelte | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index 14e5785a4..8a3ef9690 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -20,7 +20,11 @@ -
+
From 55b7c30028c96dc58e14b563dcd26780dbea34cb Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 4 Jul 2024 18:50:09 +0100 Subject: [PATCH 17/24] simplify citation API --- src/lib/components/chat/Chat.svelte | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index de64d2681..a087e76ed 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -139,7 +139,7 @@ return; } const status_keys = ["done", "description"]; - const citation_keys = ["document", "metadata", "source"]; + const citation_keys = ["document", "url", "title"]; if (type === "status" && status_keys.every(key => key in payload)) { if (message.statusHistory) { message.statusHistory.push(payload); @@ -147,10 +147,15 @@ message.statusHistory = [payload]; } } else if (type === "citation" && citation_keys.every(key => key in payload)) { + const citation = { + document: [payload.document], + metadata: [{source: payload.url}], + source: {name: payload.title} + }; if (message.citations) { - message.citations.push(payload); + message.citations.push(citation); } else { - message.citations = [payload]; + message.citations = [citation]; } } else { console.log("Unknown message type", data); From 67c2ab006d06e442c4ca7cc4e0293e119f67f715 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 13:41:18 -0700 Subject: [PATCH 18/24] fix: pipe custom model --- backend/apps/webui/main.py | 76 ++++++++++++++++++++++++++++++++++++++ backend/main.py | 6 ++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 552edf7fa..745157ac6 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,8 +19,13 @@ from apps.webui.routers import ( functions, ) from apps.webui.models.functions import Functions +from apps.webui.models.models import Models + from apps.webui.utils import load_function_module_by_id + from utils.misc import stream_message_template +from utils.task import prompt_template + from config import ( WEBUI_BUILD_HASH, @@ -186,6 +191,77 @@ async def get_pipe_models(): async def generate_function_chat_completion(form_data, user): + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + if model_info.params.get("temperature", None) is not None: + form_data["temperature"] = float(model_info.params.get("temperature")) + + if model_info.params.get("top_p", None): + form_data["top_p"] = int(model_info.params.get("top_p", None)) + + if model_info.params.get("max_tokens", None): + form_data["max_tokens"] = int(model_info.params.get("max_tokens", None)) + + if model_info.params.get("frequency_penalty", None): + form_data["frequency_penalty"] = int( + model_info.params.get("frequency_penalty", None) + ) + + if model_info.params.get("seed", None): + form_data["seed"] = model_info.params.get("seed", None) + + if model_info.params.get("stop", None): + form_data["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + system = model_info.params.get("system", None) + if system: + system = prompt_template( + system, + **( + { + "user_name": user.name, + "user_location": ( + user.info.get("location") if user.info else None + ), + } + if user + else {} + ), + ) + # Check if the payload already has a system message + # If not, add a system message to the payload + if form_data.get("messages"): + for message in form_data["messages"]: + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + form_data["messages"].insert( + 0, + { + "role": "system", + "content": system, + }, + ) + + else: + pass + async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/backend/main.py b/backend/main.py index 8f818c85b..f2019b30f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -975,12 +975,16 @@ async def get_all_models(): model["info"] = custom_model.model_dump() else: owned_by = "openai" + pipe = None + for model in models: if ( custom_model.base_model_id == model["id"] or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] break models.append( @@ -992,11 +996,11 @@ async def get_all_models(): "owned_by": owned_by, "info": custom_model.model_dump(), "preset": True, + **({"pipe": pipe} if pipe is not None else {}), } ) app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS return models From 838134637818ae64127bcb27a9208b0466b438d4 Mon Sep 17 00:00:00 2001 From: Peter De-Ath Date: Fri, 5 Jul 2024 02:05:59 +0100 Subject: [PATCH 19/24] enh: add sideways scrolling to settings tabs container --- src/lib/components/admin/Settings.svelte | 17 ++++++++++-- src/lib/components/chat/SettingsModal.svelte | 29 +++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 5538a11cf..24cf595a7 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -1,5 +1,5 @@
-
- -
-
- + {#if open} +
+ +
+ {/if} +
From 1436bb7c61b1df4dba2b5b383ecb8c86ec452f37 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 5 Jul 2024 23:38:53 -0700 Subject: [PATCH 24/24] enh: handle peewee migration --- backend/apps/webui/internal/db.py | 36 +++++++++++-- backend/apps/webui/internal/wrappers.py | 72 +++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 backend/apps/webui/internal/wrappers.py diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 333e215ea..8437ae4fa 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -2,6 +2,10 @@ import os import logging import json from contextlib import contextmanager + +from peewee_migrate import Router +from apps.webui.internal.wrappers import register_connection + from typing import Optional, Any from typing_extensions import Self @@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass + +# Workaround to handle the peewee migration +# This is required to ensure the peewee migration is handled before the alembic migration +def handle_peewee_migration(): + try: + db = register_connection(DATABASE_URL) + migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations" + router = Router(db, logger=log, migrate_dir=migrate_dir) + router.run() + db.close() + + # check if db connection has been closed + + except Exception as e: + log.error(f"Failed to initialize the database connection: {e}") + raise + + finally: + # Properly closing the database connection + if db and not db.is_closed(): + db.close() + + # Assert if db connection has been closed + assert db.is_closed(), "Database connection is still open." + + +handle_peewee_migration() + + SQLALCHEMY_DATABASE_URL = DATABASE_URL if "sqlite" in SQLALCHEMY_DATABASE_URL: engine = create_engine( @@ -62,9 +95,6 @@ Base = declarative_base() Session = scoped_session(SessionLocal) -from contextlib import contextmanager - - # Dependency def get_session(): db = SessionLocal() diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..2b5551ce2 --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,72 @@ +from contextvars import ContextVar +from peewee import * +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError + +import logging +from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["DB"]) + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(object): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + value = self._state.get()[name] + return value + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # psycopg2 + (OperationalError, "termin"), + (InterfaceError, "closed"), + # peewee + (PeeWeeInterfaceError, "closed"), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +def register_connection(db_url): + db = connect(db_url) + if isinstance(db, PostgresqlDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to PostgreSQL database") + + # Get the connection details + connection = parse(db_url) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase( + connection["database"], + user=connection["user"], + password=connection["password"], + host=connection["host"], + port=connection["port"], + ) + db.connect(reuse_if_open=True) + elif isinstance(db, SqliteDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to SQLite database") + else: + raise ValueError("Unsupported database connection") + return db