From 8ae605ec4befb751dfbedfdb756d0e787da6ca7c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 13 Oct 2024 01:00:38 -0700 Subject: [PATCH] fix: multi-user tags issue --- backend/open_webui/apps/webui/models/tags.py | 18 +++-- .../open_webui/apps/webui/routers/chats.py | 8 +-- .../versions/3ab32c4b8f59_update_tags.py | 67 +++++++++++++++++++ 3 files changed, 84 insertions(+), 9 deletions(-) create mode 100644 backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/apps/webui/models/tags.py index ef209b565..7424a2660 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/apps/webui/models/tags.py @@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, JSON +from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### class Tag(Base): __tablename__ = "tag" - id = Column(String, primary_key=True) + id = Column(String) name = Column(String) user_id = Column(String) meta = Column(JSON, nullable=True) + # Unique constraint ensuring (id, user_id) is unique, not just the `id` column + __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) + class TagModel(BaseModel): id: str @@ -57,7 +60,8 @@ class TagTable: return TagModel.model_validate(result) else: return None - except Exception: + except Exception as e: + print(e) return None def get_tag_by_name_and_user_id( @@ -78,11 +82,15 @@ class TagTable: for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] - def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]: + def get_tags_by_ids_and_user_id( + self, ids: list[str], user_id: str + ) -> list[TagModel]: with get_db() as db: return [ TagModel.model_validate(tag) - for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all()) + for tag in ( + db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all() + ) ] def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index 6a9c26f8c..b919d1447 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids(tags) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name( chat = Chats.get_chat_by_id_and_user_id(id, user.id) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids(tags) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name( chat = Chats.get_chat_by_id_and_user_id(id, user.id) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids(tags) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids(tags) + return Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND diff --git a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py new file mode 100644 index 000000000..7c7126e2f --- /dev/null +++ b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py @@ -0,0 +1,67 @@ +"""Update tags + +Revision ID: 3ab32c4b8f59 +Revises: 1af9b942657b +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update, column +from sqlalchemy.engine.reflection import Inspector + +import json + +revision = "3ab32c4b8f59" +down_revision = "1af9b942657b" +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + # Inspecting the 'tag' table constraints and structure + existing_pk = inspector.get_pk_constraint("tag") + unique_constraints = inspector.get_unique_constraints("tag") + existing_indexes = inspector.get_indexes("tag") + + print(existing_pk, unique_constraints) + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Drop unique constraints that could conflict with new primary key + for constraint in unique_constraints: + if constraint["name"] == "uq_id_user_id": + batch_op.drop_constraint(constraint["name"], type_="unique") + + for index in existing_indexes: + if index["unique"]: + # Drop the unique index + batch_op.drop_index(index["name"]) + + # Drop existing primary key constraint if it exists + if existing_pk and existing_pk.get("constrained_columns"): + batch_op.drop_constraint(existing_pk["name"], type_="primary") + + # Immediately after dropping the old primary key, create the new one + batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"]) + + +def downgrade(): + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + current_pk = inspector.get_pk_constraint("tag") + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Drop the current primary key first, if it matches the one we know we added in upgrade + if current_pk and "pk_id_user_id" == current_pk.get("name"): + batch_op.drop_constraint("pk_id_user_id", type_="primary") + + # Restore the original primary key + batch_op.create_primary_key("pk_id", ["id"]) + + # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary + batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])