diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 320ab3e07..8437ae4fa 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -2,6 +2,10 @@ import os import logging import json from contextlib import contextmanager + +from peewee_migrate import Router +from apps.webui.internal.wrappers import register_connection + from typing import Optional, Any from typing_extensions import Self @@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass + +# Workaround to handle the peewee migration +# This is required to ensure the peewee migration is handled before the alembic migration +def handle_peewee_migration(): + try: + db = register_connection(DATABASE_URL) + migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations" + router = Router(db, logger=log, migrate_dir=migrate_dir) + router.run() + db.close() + + # check if db connection has been closed + + except Exception as e: + log.error(f"Failed to initialize the database connection: {e}") + raise + + finally: + # Properly closing the database connection + if db and not db.is_closed(): + db.close() + + # Assert if db connection has been closed + assert db.is_closed(), "Database connection is still open." + + +handle_peewee_migration() + + SQLALCHEMY_DATABASE_URL = DATABASE_URL if "sqlite" in SQLALCHEMY_DATABASE_URL: engine = create_engine( @@ -53,8 +86,22 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + + SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) Base = declarative_base() Session = scoped_session(SessionLocal) + + +# Dependency +def get_session(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +get_db = contextmanager(get_session) diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..2b5551ce2 --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,72 @@ +from contextvars import ContextVar +from peewee import * +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError + +import logging +from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["DB"]) + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(object): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + value = self._state.get()[name] + return value + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # psycopg2 + (OperationalError, "termin"), + (InterfaceError, "closed"), + # peewee + (PeeWeeInterfaceError, "closed"), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +def register_connection(db_url): + db = connect(db_url) + if isinstance(db, PostgresqlDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to PostgreSQL database") + + # Get the connection details + connection = parse(db_url) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase( + connection["database"], + user=connection["user"], + password=connection["password"], + host=connection["host"], + port=connection["port"], + ) + db.connect(reuse_if_open=True) + elif isinstance(db, SqliteDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to SQLite database") + else: + raise ValueError("Unsupported database connection") + return db diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 10f271bed..a712bbe28 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,8 +19,13 @@ from apps.webui.routers import ( functions, ) from apps.webui.models.functions import Functions +from apps.webui.models.models import Models + from apps.webui.utils import load_function_module_by_id + from utils.misc import stream_message_template +from utils.task import prompt_template + from config import ( WEBUI_BUILD_HASH, @@ -186,6 +191,77 @@ async def get_pipe_models(): async def generate_function_chat_completion(form_data, user): + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + if model_info.params.get("temperature", None) is not None: + form_data["temperature"] = float(model_info.params.get("temperature")) + + if model_info.params.get("top_p", None): + form_data["top_p"] = int(model_info.params.get("top_p", None)) + + if model_info.params.get("max_tokens", None): + form_data["max_tokens"] = int(model_info.params.get("max_tokens", None)) + + if model_info.params.get("frequency_penalty", None): + form_data["frequency_penalty"] = int( + model_info.params.get("frequency_penalty", None) + ) + + if model_info.params.get("seed", None): + form_data["seed"] = model_info.params.get("seed", None) + + if model_info.params.get("stop", None): + form_data["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + system = model_info.params.get("system", None) + if system: + system = prompt_template( + system, + **( + { + "user_name": user.name, + "user_location": ( + user.info.get("location") if user.info else None + ), + } + if user + else {} + ), + ) + # Check if the payload already has a system message + # If not, add a system message to the payload + if form_data.get("messages"): + for message in form_data["messages"]: + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + form_data["messages"].insert( + 0, + { + "role": "system", + "content": system, + }, + ) + + else: + pass + async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 560d9a686..7698359f9 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -102,40 +102,44 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - log.info("insert_new_auth") + with get_db() as db: - id = str(uuid.uuid4()) + log.info("insert_new_auth") - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - Session.add(result) + id = str(uuid.uuid4()) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub - ) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + db.add(result) - Session.commit() - Session.refresh(result) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) - if result and user: - return user - else: - return None + db.commit() + db.refresh(result) + + if result and user: + return user + else: + return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = Session.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + + auth = db.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user + else: + return None else: return None - else: - return None except: return None @@ -154,40 +158,47 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = Session.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + auth = db.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user except: return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = ( - Session.query(Auth).filter_by(id=id).update({"password": new_password}) - ) - return True if result == 1 else False + with get_db() as db: + + result = ( + db.query(Auth).filter_by(id=id).update({"password": new_password}) + ) + return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = Session.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False + with get_db() as db: + + result = db.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False except: return False def delete_auth_by_id(self, id: str) -> bool: try: - # Delete User - result = Users.delete_user_by_id(id) + with get_db() as db: - if result: - Session.query(Auth).filter_by(id=id).delete() + # Delete User + result = Users.delete_user_by_id(id) - return True - else: - return False + if result: + db.query(Auth).filter_by(id=id).delete() + + return True + else: + return False except: return False diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d6829ee7b..8d2e6b104 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import time from sqlalchemy import Column, String, BigInteger, Boolean, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db #################### @@ -79,87 +79,99 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + with get_db() as db: - result = Chat(**chat.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - return ChatModel.model_validate(result) if result else None + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - chat_obj = Session.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - Session.commit() - Session.refresh(chat_obj) + with get_db() as db: - return ChatModel.model_validate(chat_obj) + chat_obj = db.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + db.commit() + db.refresh(chat_obj) + + return ChatModel.model_validate(chat_obj) except Exception as e: return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - # Get the existing chat to share - chat = Session.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - Session.add(shared_result) - Session.commit() - Session.refresh(shared_result) - # Update the original chat with the share_id - result = ( - Session.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + with get_db() as db: - return shared_chat if (shared_result and result) else None + # Get the existing chat to share + chat = db.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) + # Update the original chat with the share_id + result = ( + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) + + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: - print("update_shared_chat_by_id") - chat = Session.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - Session.commit() - Session.refresh(chat) + with get_db() as db: - return self.get_chat_by_id(chat.share_id) + print("update_shared_chat_by_id") + chat = db.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + db.commit() + db.refresh(chat) + + return self.get_chat_by_id(chat.share_id) except: return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + return True except: return False @@ -167,42 +179,50 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.share_id = share_id - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.share_id = share_id + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.archived = not chat.archived - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.archived = not chat.archived + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + return True except: return False def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, @@ -211,110 +231,136 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - query = Session.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(share_id=id).first() + with get_db() as db: - if chat: - return self.get_chat_by_id(id) - else: - return None + chat = db.query(Chat).filter_by(share_id=id).first() + + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - Session.query(Chat).filter_by(id=id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - Session.query(Chat).filter_by(id=id, user_id=user_id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id, user_id=user_id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - self.delete_shared_chats_by_user_id(user_id) - Session.query(Chat).filter_by(user_id=user_id).delete() - return True + with get_db() as db: + + self.delete_shared_chats_by_user_id(user_id) + + db.query(Chat).filter_by(user_id=user_id).delete() + return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + with get_db() as db: - return True + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + + return True except: return False diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 1b69d44a5..16145c4ac 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -74,51 +74,59 @@ class DocumentsTable: def insert_new_doc( self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) + with get_db() as db: - try: - result = Document(**document.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: + document = DocumentModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "timestamp": int(time.time()), + } + ) + + try: + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None + except: return None - except: - return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - document = Session.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + with get_db() as db: + + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - return [ - DocumentModel.model_validate(doc) for doc in Session.query(Document).all() - ] + with get_db() as db: + + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - Session.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(form_data.name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -131,22 +139,26 @@ class DocumentsTable: doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated} - Session.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - Session.query(Document).filter_by(name=name).delete() - return True + with get_db() as db: + + db.query(Document).filter_by(name=name).delete() + return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index ce904215d..58058e907 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db import json @@ -61,50 +61,62 @@ class FileForm(BaseModel): class FilesTable: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - file = FileModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - } - ) + with get_db() as db: - try: - result = File(**file.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FileModel.model_validate(result) - else: + file = FileModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + } + ) + + try: + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_file_by_id(self, id: str) -> Optional[FileModel]: - try: - file = Session.get(File, id) - return FileModel.model_validate(file) - except: - return None + with get_db() as db: + + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except: + return None def get_files(self) -> List[FileModel]: - return [FileModel.model_validate(file) for file in Session.query(File).all()] + with get_db() as db: + + return [FileModel.model_validate(file) for file in db.query(File).all()] def delete_file_by_id(self, id: str) -> bool: - try: - Session.query(File).filter_by(id=id).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).filter_by(id=id).delete() + return True + except: + return False def delete_all_files(self) -> bool: - try: - Session.query(File).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).delete() + return True + except: + return False Files = FilesTable() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 64ed4f3cc..cdc1bd334 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db from apps.webui.models.users import Users import json @@ -91,6 +91,7 @@ class FunctionsTable: def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: + function = FunctionModel( **{ **form_data.model_dump(), @@ -102,88 +103,102 @@ class FunctionsTable: ) try: - result = Function(**function.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + with get_db() as db: + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - function = Session.get(Function, id) - return FunctionModel.model_validate(function) + with get_db() as db: + + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(is_active=True).all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(is_active=True).all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(type=type).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(type=type).all() + ] def get_global_filter_functions(self) -> List[FunctionModel]: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() - ] + with get_db() as db: + + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type="filter", is_active=True, is_global=True) + .all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: - try: - function = Session.get(Function, id) - return function.valves if function.valves else {} - except Exception as e: - print(f"An error occurred: {e}") - return None + with get_db() as db: + + try: + function = db.get(Function, id) + return function.valves if function.valves else {} + except Exception as e: + print(f"An error occurred: {e}") + return None def update_function_valves_by_id( self, id: str, valves: dict ) -> Optional[FunctionValves]: - try: - function = Session.get(Function, id) - function.valves = valves - function.updated_at = int(time.time()) - Session.commit() - Session.refresh(function) - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + function = db.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + db.commit() + db.refresh(function) + return self.get_function_by_id(id) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: @@ -199,9 +214,10 @@ class FunctionsTable: def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: @@ -220,37 +236,43 @@ class FunctionsTable: return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - try: - Session.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - Session.commit() - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_function_by_id(id) + except: + return None def deactivate_all_functions(self) -> Optional[bool]: - try: - Session.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - Session.commit() - return True - except: - return None + with get_db() as db: + + try: + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + db.commit() + return True + except: + return None def delete_function_by_id(self, id: str) -> bool: - try: - Session.query(Function).filter_by(id=id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).delete() + return True + except: + return False Functions = FunctionsTable() diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 1f03318fd..662bbedfe 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,7 +3,7 @@ from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import time import uuid @@ -45,82 +45,98 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - id = str(uuid.uuid4()) - memory = MemoryModel( - **{ - "id": id, - "user_id": user_id, - "content": content, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) - result = Memory(**memory.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + with get_db() as db: + id = str(uuid.uuid4()) + + memory = MemoryModel( + **{ + "id": id, + "user_id": user_id, + "content": content, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + result = Memory(**memory.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, id: str, content: str, ) -> Optional[MemoryModel]: - try: - Session.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_memory_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + db.commit() + return self.get_memory_by_id(id) + except: + return None def get_memories(self) -> List[MemoryModel]: - try: - memories = Session.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: - try: - memories = Session.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - try: - memory = Session.get(Memory, id) - return MemoryModel.model_validate(memory) - except: - return None + with get_db() as db: + + try: + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) + except: + return None def delete_memory_by_id(self, id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id).delete() - return True + with get_db() as db: - except: - return False + try: + db.query(Memory).filter_by(id=id).delete() + return True + + except: + return False def delete_memories_by_user_id(self, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(user_id=user_id).delete() + return True + except: + return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id, user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id, user_id=user_id).delete() + return True + except: + return False Memories = MemoriesTable() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 6543edefc..c95c36c7d 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -126,39 +126,46 @@ class ModelsTable: } ) try: - result = Model(**model.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Model(**model.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - return [ - ModelModel.model_validate(model) for model in Session.query(Model).all() - ] + with get_db() as db: + + return [ModelModel.model_validate(model) for model in db.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - model = Session.get(Model, id) - return ModelModel.model_validate(model) + with get_db() as db: + + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: - # update only the fields that are present in the model - model = Session.query(Model).get(id) - model.update(**model.model_dump()) - Session.commit() - Session.refresh(model) - return ModelModel.model_validate(model) + with get_db() as db: + + # update only the fields that are present in the model + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -166,8 +173,10 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - Session.query(Model).filter_by(id=id).delete() - return True + with get_db() as db: + + db.query(Model).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index ab8cc04ce..2af2ce22c 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -4,7 +4,7 @@ import time from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -60,46 +60,56 @@ class PromptsTable: ) try: - result = Prompt(**prompt.dict()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return PromptModel.model_validate(result) + else: + return None except Exception as e: return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) except: return None def get_prompts(self) -> List[PromptModel]: - return [ - PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() - ] + with get_db() as db: + + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - Session.commit() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + db.commit() + return PromptModel.model_validate(prompt) except: return None def delete_prompt_by_command(self, command: str) -> bool: try: - Session.query(Prompt).filter_by(command=command).delete() - return True + with get_db() as db: + + db.query(Prompt).filter_by(command=command).delete() + return True except: return False diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7b0df6b6b..bbbc95ed2 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -79,26 +79,29 @@ class ChatTagsResponse(BaseModel): class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - id = str(uuid.uuid4()) - tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) - try: - result = Tag(**tag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return TagModel.model_validate(result) - else: + with get_db() as db: + + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + try: + result = Tag(**tag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None + except Exception as e: return None - except Exception as e: - return None def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = Session.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + with get_db() as db: + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -120,98 +123,109 @@ class TagTable: } ) try: - result = ChatIdTag(**chatIdTag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + with get_db() as db: + result = ChatIdTag(**chatIdTag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] + + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - return ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + with get_db() as db: + + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + return True except Exception as e: log.error(f"delete_tag: {e}") return False @@ -220,20 +234,24 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + + return True except Exception as e: log.error(f"delete_tag: {e}") return False diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index f5df10637..b3964a9b8 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -4,7 +4,7 @@ import time import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users import json @@ -83,54 +83,64 @@ class ToolsTable: def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ToolModel.model_validate(result) - else: + with get_db() as db: + + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool(**tool.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - return ToolModel.model_validate(tool) + with get_db() as db: + + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] + with get_db() as db: + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - tool = Session.get(Tool, id) - return tool.valves if tool.valves else {} + with get_db() as db: + + tool = db.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - Session.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_tool_by_id(id) + with get_db() as db: + + db.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None @@ -139,7 +149,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: @@ -157,7 +167,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: @@ -177,19 +187,21 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - tool.update(**updated) - tool.updated_at = int(time.time()) - Session.commit() - Session.refresh(tool) - return ToolModel.model_validate(tool) + with get_db() as db: + tool = db.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + db.commit() + db.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - Session.query(Tool).filter_by(id=id).delete() - return True + with get_db() as db: + db.query(Tool).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 9e1e25ac6..8e3b57bba 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, Session, get_db from apps.webui.models.chats import Chats #################### @@ -88,81 +88,92 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return user - else: - return None + with get_db() as db: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except Exception as e: return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) except: return None def get_user_by_email(self, email: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) except: return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) except: return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - users = ( - Session.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + with get_db() as db: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - return Session.query(User).count() + with get_db() as db: + return db.query(User).count() def get_first_user(self) -> UserModel: try: - user = Session.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) except: return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"role": role}) - Session.commit() - - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + db.query(User).filter_by(id=id).update({"role": role}) + db.commit() + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -170,25 +181,28 @@ class UsersTable: self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) - Session.commit() + with get_db() as db: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -196,21 +210,23 @@ class UsersTable: self, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + with get_db() as db: + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update(updated) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update(updated) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) except Exception as e: return None @@ -220,9 +236,10 @@ class UsersTable: result = Chats.delete_chats_by_user_id(id) if result: - # Delete User - Session.query(User).filter_by(id=id).delete() - Session.commit() + with get_db() as db: + # Delete User + db.query(User).filter_by(id=id).delete() + db.commit() return True else: @@ -232,16 +249,18 @@ class UsersTable: def update_user_api_key_by_id(self, id: str, api_key: str) -> str: try: - result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) - Session.commit() - return True if result == 1 else False + with get_db() as db: + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() + return True if result == 1 else False except: return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: - user = Session.query(User).filter_by(id=id).first() - return user.api_key + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return user.api_key except Exception as e: return None diff --git a/backend/main.py b/backend/main.py index cb0cef4f4..2da19c5c7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -999,12 +999,16 @@ async def get_all_models(): model["info"] = custom_model.model_dump() else: owned_by = "openai" + pipe = None + for model in models: if ( custom_model.base_model_id == model["id"] or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] break models.append( @@ -1016,11 +1020,11 @@ async def get_all_models(): "owned_by": owned_by, "info": custom_model.model_dump(), "preset": True, + **({"pipe": pipe} if pipe is not None else {}), } ) app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS return models diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 5538a11cf..24cf595a7 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -1,5 +1,5 @@