diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index 1cf56c351..d6829ee7b 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -85,9 +85,7 @@ class ChatTable: "id": id, "user_id": user_id, "title": ( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" + form_data.chat["title"] if "title" in form_data.chat else "New Chat" ), "chat": json.dumps(form_data.chat), "created_at": int(time.time()), @@ -197,14 +195,14 @@ class ChatTable: def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index f0bd6e291..1f03318fd 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -115,9 +115,7 @@ class MemoriesTable: except: return False - def delete_memory_by_id_and_user_id( - self, id: str, user_id: str - ) -> bool: + def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: Session.query(Memory).filter_by(id=id, user_id=user_id).delete() return True diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 7d1da54ff..6543edefc 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -140,7 +140,9 @@ class ModelsTable: return None def get_all_models(self) -> List[ModelModel]: - return [ModelModel.model_validate(model) for model in Session.query(Model).all()] + return [ + ModelModel.model_validate(model) for model in Session.query(Model).all() + ] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 87238c2a3..7b0df6b6b 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -207,9 +207,7 @@ class TagTable: log.debug(f"res: {res}") Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) + tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) if tag_count == 0: # Remove tag item from Tag col as well Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() @@ -230,9 +228,7 @@ class TagTable: log.debug(f"res: {res}") Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) + tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) if tag_count == 0: # Remove tag item from Tag col as well Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() diff --git a/backend/main.py b/backend/main.py index ad519bdcb..a29fde198 100644 --- a/backend/main.py +++ b/backend/main.py @@ -793,6 +793,7 @@ app.add_middleware( allow_headers=["*"], ) + @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index bd49d1b43..50deac526 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -5,6 +5,7 @@ Revises: Create Date: 2024-06-24 13:15:33.808998 """ + from typing import Sequence, Union from alembic import op @@ -13,7 +14,7 @@ import apps.webui.internal.db from migrations.util import get_existing_tables # revision identifiers, used by Alembic. -revision: str = '7e5b5dc7342b' +revision: str = "7e5b5dc7342b" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,163 +25,175 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### if "auth" not in existing_tables: - op.create_table('auth', - sa.Column('id', sa.String(), nullable=False), - sa.Column('email', sa.String(), nullable=True), - sa.Column('password', sa.Text(), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "auth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=True), + sa.Column("password", sa.Text(), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "chat" not in existing_tables: - op.create_table('chat', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('chat', sa.Text(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('share_id', sa.Text(), nullable=True), - sa.Column('archived', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('share_id') + op.create_table( + "chat", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("chat", sa.Text(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("share_id", sa.Text(), nullable=True), + sa.Column("archived", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("share_id"), ) if "chatidtag" not in existing_tables: - op.create_table('chatidtag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('tag_name', sa.String(), nullable=True), - sa.Column('chat_id', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "chatidtag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag_name", sa.String(), nullable=True), + sa.Column("chat_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "document" not in existing_tables: - op.create_table('document', - sa.Column('collection_name', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('filename', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('collection_name'), - sa.UniqueConstraint('name') + op.create_table( + "document", + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("filename", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("collection_name"), + sa.UniqueConstraint("name"), ) if "file" not in existing_tables: - op.create_table('file', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('filename', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "file", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("filename", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "function" not in existing_tables: - op.create_table('function', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "function", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "memory" not in existing_tables: - op.create_table('memory', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "memory", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "model" not in existing_tables: - op.create_table('model', - sa.Column('id', sa.Text(), nullable=False), - sa.Column('user_id', sa.Text(), nullable=True), - sa.Column('base_model_id', sa.Text(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "model", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=True), + sa.Column("base_model_id", sa.Text(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "prompt" not in existing_tables: - op.create_table('prompt', - sa.Column('command', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('command') + op.create_table( + "prompt", + sa.Column("command", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("command"), ) if "tag" not in existing_tables: - op.create_table('tag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('data', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("data", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "tool" not in existing_tables: - op.create_table('tool', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tool", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "user" not in existing_tables: - op.create_table('user', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('email', sa.String(), nullable=True), - sa.Column('role', sa.String(), nullable=True), - sa.Column('profile_image_url', sa.Text(), nullable=True), - sa.Column('last_active_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('api_key', sa.String(), nullable=True), - sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('api_key') + op.create_table( + "user", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("profile_image_url", sa.Text(), nullable=True), + sa.Column("last_active_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("api_key"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('user') - op.drop_table('tool') - op.drop_table('tag') - op.drop_table('prompt') - op.drop_table('model') - op.drop_table('memory') - op.drop_table('function') - op.drop_table('file') - op.drop_table('document') - op.drop_table('chatidtag') - op.drop_table('chat') - op.drop_table('auth') + op.drop_table("user") + op.drop_table("tool") + op.drop_table("tag") + op.drop_table("prompt") + op.drop_table("model") + op.drop_table("memory") + op.drop_table("function") + op.drop_table("file") + op.drop_table("document") + op.drop_table("chatidtag") + op.drop_table("chat") + op.drop_table("auth") # ### end Alembic commands ### diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index 6d2dd35b1..f4661b625 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -91,6 +91,7 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") from apps.webui.internal.db import Session + Session.commit() with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/all/archived")) diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index f8d6d4ff7..4e99dcc2f 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -110,6 +110,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): def _check_db_connection(self): from apps.webui.internal.db import Session + retries = 10 while retries > 0: try: @@ -133,6 +134,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): def teardown_method(self): from apps.webui.internal.db import Session + # rollback everything not yet committed Session.commit()