mirror of
https://github.com/open-webui/open-webui
synced 2025-01-31 06:49:03 +00:00
fix: multi-user tags issue
This commit is contained in:
parent
5273dc4535
commit
8ae605ec4b
@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, JSON
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
####################
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
id = Column(String, primary_key=True)
|
||||
id = Column(String)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
||||
|
||||
|
||||
class TagModel(BaseModel):
|
||||
id: str
|
||||
@ -57,7 +60,8 @@ class TagTable:
|
||||
return TagModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
@ -78,11 +82,15 @@ class TagTable:
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
|
||||
def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str
|
||||
) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
|
||||
for tag in (
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
||||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
|
@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name(
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name(
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
@ -0,0 +1,67 @@
|
||||
"""Update tags
|
||||
|
||||
Revision ID: 3ab32c4b8f59
|
||||
Revises: 1af9b942657b
|
||||
Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update, column
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
revision = "3ab32c4b8f59"
|
||||
down_revision = "1af9b942657b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
# Inspecting the 'tag' table constraints and structure
|
||||
existing_pk = inspector.get_pk_constraint("tag")
|
||||
unique_constraints = inspector.get_unique_constraints("tag")
|
||||
existing_indexes = inspector.get_indexes("tag")
|
||||
|
||||
print(existing_pk, unique_constraints)
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Drop unique constraints that could conflict with new primary key
|
||||
for constraint in unique_constraints:
|
||||
if constraint["name"] == "uq_id_user_id":
|
||||
batch_op.drop_constraint(constraint["name"], type_="unique")
|
||||
|
||||
for index in existing_indexes:
|
||||
if index["unique"]:
|
||||
# Drop the unique index
|
||||
batch_op.drop_index(index["name"])
|
||||
|
||||
# Drop existing primary key constraint if it exists
|
||||
if existing_pk and existing_pk.get("constrained_columns"):
|
||||
batch_op.drop_constraint(existing_pk["name"], type_="primary")
|
||||
|
||||
# Immediately after dropping the old primary key, create the new one
|
||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
current_pk = inspector.get_pk_constraint("tag")
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Drop the current primary key first, if it matches the one we know we added in upgrade
|
||||
if current_pk and "pk_id_user_id" == current_pk.get("name"):
|
||||
batch_op.drop_constraint("pk_id_user_id", type_="primary")
|
||||
|
||||
# Restore the original primary key
|
||||
batch_op.create_primary_key("pk_id", ["id"])
|
||||
|
||||
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
|
||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
Loading…
Reference in New Issue
Block a user