diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 6f1511cd1..cb19d83e8 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -6,6 +6,8 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from sqlalchemy.orm import relationship + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -32,6 +34,8 @@ class File(Base): created_at = Column(BigInteger) updated_at = Column(BigInteger) + users = relationship("User", secondary="shared_file_owner", back_populates="files") + class FileModel(BaseModel): model_config = ConfigDict(from_attributes=True) diff --git a/backend/open_webui/models/shared_file_owner.py b/backend/open_webui/models/shared_file_owner.py new file mode 100644 index 000000000..ef6af0be3 --- /dev/null +++ b/backend/open_webui/models/shared_file_owner.py @@ -0,0 +1,50 @@ +import logging + +from sqlalchemy import exists + +from open_webui.env import SRC_LOG_LEVELS +from open_webui.internal.db import get_db +from open_webui.models.users import shared_file_owner + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +class SharedFileOwner: + + @staticmethod + def has_access(access_type: str, file_id: str, user_id: str) -> bool: + """Checks if a user-file link exists in the database.""" + if access_type != "read": + return False + + with get_db() as db: + try: + stmt = exists().where( + shared_file_owner.c.user_id == user_id, + shared_file_owner.c.file_id == file_id + ).select() + + result = db.execute(stmt).scalar() + return result + except Exception as e: + log.error(f"An error occurred: {e}") + return False + + @staticmethod + def add_shared_file_owner(user_id: str, chat: dict) -> None: + file_ids = [] + for file in chat.get("files", []): + file_id = file.get("id") + if not file_id: + continue + file_ids.append(file_id) + with get_db() as db: + from sqlalchemy import text + files_to_add = db.execute(text("SELECT id FROM file WHERE id IN :file_ids"), {"file_ids": tuple(file_ids)}).fetchall() + existing_file_ids = db.execute(text("SELECT file_id FROM shared_file_owner WHERE user_id = :user_id"), {"user_id": user_id}).fetchall() + existing_file_ids = {f[0] for f in existing_file_ids} + for file_id in files_to_add: + file_id = file_id[0] + if file_id not in existing_file_ids: + db.execute(shared_file_owner.insert().values(user_id=user_id, file_id=file_id)) + db.commit() \ No newline at end of file diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 00d504088..0052696b5 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -9,7 +9,8 @@ from open_webui.models.groups import Groups from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, Table, ForeignKey, Index +from sqlalchemy.orm import relationship from sqlalchemy import or_ @@ -17,6 +18,14 @@ from sqlalchemy import or_ # User DB Schema #################### +shared_file_owner = Table( + "shared_file_owner", + Base.metadata, + Column("user_id", ForeignKey("user.id"), primary_key=True), + Column("file_id", ForeignKey("file.id"), primary_key=True), + Index("idx_shared_file_owner_user_id", "user_id"), + Index("idx_shared_file_owner_file_id", "file_id"), +) class User(Base): __tablename__ = "user" @@ -37,6 +46,8 @@ class User(Base): oauth_sub = Column(Text, unique=True) + files = relationship("File", secondary=shared_file_owner, back_populates="users") + class UserSettings(BaseModel): ui: Optional[dict] = {} diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 29b12ed67..0a4e8f17f 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -11,6 +11,8 @@ from open_webui.models.chats import ( Chats, ChatTitleIdResponse, ) +from open_webui.models.shared_file_owner import SharedFileOwner + from open_webui.models.tags import TagModel, Tags from open_webui.models.folders import Folders @@ -21,6 +23,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel + from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_permission @@ -610,6 +613,8 @@ async def clone_chat_by_id( } chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) + SharedFileOwner.add_shared_file_owner(user.id, chat.chat) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index bdf5780fc..001921340 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -21,6 +21,7 @@ from fastapi import ( from fastapi.responses import FileResponse, StreamingResponse from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS +from open_webui.models.shared_file_owner import SharedFileOwner from open_webui.models.users import Users from open_webui.models.files import ( @@ -63,6 +64,10 @@ def has_access_to_file( ) has_access = False + + if SharedFileOwner.has_access(access_type, file_id, user.id): + return True + knowledge_base_id = file.meta.get("collection_name") if file.meta else None if knowledge_base_id: