diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index e1604d126..5381daf0c 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -35,7 +35,6 @@ from open_webui.utils.plugin import ( get_function_module_from_cache, ) from open_webui.utils.tools import get_tools -from open_webui.utils.access_control import has_access from open_webui.env import GLOBAL_LOG_LEVEL diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 5fab6990d..7d6ae9cba 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -516,7 +516,6 @@ from open_webui.utils.middleware import ( process_chat_payload, process_chat_response, ) -from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( get_license_data, diff --git a/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py new file mode 100644 index 000000000..1b76e67c3 --- /dev/null +++ b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py @@ -0,0 +1,350 @@ +"""Add access_grant table + +Revision ID: f1e2d3c4b5a6 +Revises: 8452d01d26d7 +Create Date: 2026-02-05 10:00:00.000000 + +Migrates from JSON access_control columns to normalized access_grant table. +Access control semantics: +- NULL: Public access (all users can read) -> insert user:* for read +- {}: Private/owner-only (no grants) -> insert nothing +- {read: {...}, write: {...}}: Custom permissions -> insert specific grants +""" + +from typing import Sequence, Union +import time +import uuid + +from alembic import op +import sqlalchemy as sa + +from open_webui.migrations.util import get_existing_tables + +revision: str = "f1e2d3c4b5a6" +down_revision: Union[str, None] = "8452d01d26d7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + existing_tables = set(get_existing_tables()) + + # Create access_grant table + if "access_grant" not in existing_tables: + op.create_table( + "access_grant", + sa.Column("id", sa.Text(), nullable=False, primary_key=True), + sa.Column("resource_type", sa.Text(), nullable=False), + sa.Column("resource_id", sa.Text(), nullable=False), + sa.Column("principal_type", sa.Text(), nullable=False), + sa.Column("principal_id", sa.Text(), nullable=False), + sa.Column("permission", sa.Text(), nullable=False), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.UniqueConstraint( + "resource_type", + "resource_id", + "principal_type", + "principal_id", + "permission", + name="uq_access_grant_grant", + ), + ) + op.create_index( + "idx_access_grant_resource", + "access_grant", + ["resource_type", "resource_id"], + ) + op.create_index( + "idx_access_grant_principal", + "access_grant", + ["principal_type", "principal_id"], + ) + + # Backfill existing access_control JSON data + conn = op.get_bind() + + # Tables with access_control JSON columns: (table_name, resource_type) + resource_tables = [ + ("knowledge", "knowledge"), + ("prompt", "prompt"), + ("tool", "tool"), + ("model", "model"), + ("note", "note"), + ("channel", "channel"), + ("file", "file"), + ] + + now = int(time.time()) + inserted = set() + + for table_name, resource_type in resource_tables: + if table_name not in existing_tables: + continue + + # Query all rows + try: + result = conn.execute( + sa.text(f'SELECT id, access_control FROM "{table_name}"') + ) + rows = result.fetchall() + except Exception: + continue + + for row in rows: + resource_id = row[0] + access_control_json = row[1] + + # Handle NULL or JSON "null" = public access (user:* for read) + # Could be Python None (SQL NULL) or string "null" (JSON null) + # EXCEPTION: files with NULL are PRIVATE (owner-only), not public + is_null = ( + access_control_json is None or + access_control_json == "null" or + (isinstance(access_control_json, str) and access_control_json.strip().lower() == "null") + ) + if is_null: + # Files: NULL = private (no entry needed, owner has implicit access) + # Other resources: NULL = public (insert user:* for read) + if resource_type == "file": + continue # Private - no entry needed + + key = (resource_type, resource_id, "user", "*", "read") + if key not in inserted: + try: + conn.execute( + sa.text( + """ + INSERT INTO access_grant (id, resource_type, resource_id, principal_type, principal_id, permission, created_at) + VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) + """ + ), + { + "id": str(uuid.uuid4()), + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "user", + "principal_id": "*", + "permission": "read", + "created_at": now, + }, + ) + inserted.add(key) + except Exception: + pass + continue + + # Handle JSON parsing + if isinstance(access_control_json, str): + import json + + try: + access_control_json = json.loads(access_control_json) + except Exception: + continue + + # Handle {} = private/owner-only - NO entries needed + # Owner access is implicit, no grants to store + if not access_control_json or not isinstance(access_control_json, dict): + continue + + # Check if it's effectively empty (no read/write keys with content) + read_data = access_control_json.get("read", {}) + write_data = access_control_json.get("write", {}) + + has_read_grants = read_data.get("group_ids", []) or read_data.get( + "user_ids", [] + ) + has_write_grants = write_data.get("group_ids", []) or write_data.get( + "user_ids", [] + ) + + if not has_read_grants and not has_write_grants: + # Empty permissions = private, no grants needed + continue + + # Extract permissions and insert into access_grant table + for permission in ["read", "write"]: + perm_data = access_control_json.get(permission, {}) + if not perm_data: + continue + + for group_id in perm_data.get("group_ids", []): + key = (resource_type, resource_id, "group", group_id, permission) + if key in inserted: + continue + try: + conn.execute( + sa.text( + """ + INSERT INTO access_grant (id, resource_type, resource_id, principal_type, principal_id, permission, created_at) + VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) + """ + ), + { + "id": str(uuid.uuid4()), + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "group", + "principal_id": group_id, + "permission": permission, + "created_at": now, + }, + ) + inserted.add(key) + except Exception: + pass + + for user_id in perm_data.get("user_ids", []): + key = (resource_type, resource_id, "user", user_id, permission) + if key in inserted: + continue + try: + conn.execute( + sa.text( + """ + INSERT INTO access_grant (id, resource_type, resource_id, principal_type, principal_id, permission, created_at) + VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) + """ + ), + { + "id": str(uuid.uuid4()), + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "user", + "principal_id": user_id, + "permission": permission, + "created_at": now, + }, + ) + inserted.add(key) + except Exception: + pass + + # Drop access_control columns from resource tables + for table_name, _ in resource_tables: + if table_name not in existing_tables: + continue + try: + with op.batch_alter_table(table_name) as batch: + batch.drop_column("access_control") + except Exception: + pass + + +def downgrade() -> None: + import json + + conn = op.get_bind() + + # Resource tables mapping: (table_name, resource_type) + resource_tables = [ + ("knowledge", "knowledge"), + ("prompt", "prompt"), + ("tool", "tool"), + ("model", "model"), + ("note", "note"), + ("channel", "channel"), + ("file", "file"), + ] + + # Step 1: Re-add access_control columns to resource tables + for table_name, _ in resource_tables: + try: + with op.batch_alter_table(table_name) as batch: + batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True)) + except Exception: + pass + + # Step 2: Query access_grant table and reconstruct JSON for each resource + for table_name, resource_type in resource_tables: + try: + # Get all grants for this resource type + result = conn.execute( + sa.text(""" + SELECT resource_id, principal_type, principal_id, permission + FROM access_grant + WHERE resource_type = :resource_type + """), + {"resource_type": resource_type} + ) + rows = result.fetchall() + except Exception: + continue + + # Group by resource_id and reconstruct JSON structure + resource_grants = {} + for row in rows: + resource_id = row[0] + principal_type = row[1] + principal_id = row[2] + permission = row[3] + + if resource_id not in resource_grants: + resource_grants[resource_id] = { + "is_public": False, + "read": {"group_ids": [], "user_ids": []}, + "write": {"group_ids": [], "user_ids": []}, + } + + # Handle public access (user:* for read) + if principal_type == "user" and principal_id == "*" and permission == "read": + resource_grants[resource_id]["is_public"] = True + continue + + # Add to appropriate list + if permission in ["read", "write"]: + if principal_type == "group": + if principal_id not in resource_grants[resource_id][permission]["group_ids"]: + resource_grants[resource_id][permission]["group_ids"].append(principal_id) + elif principal_type == "user": + if principal_id not in resource_grants[resource_id][permission]["user_ids"]: + resource_grants[resource_id][permission]["user_ids"].append(principal_id) + + # Step 3: Update each resource with reconstructed JSON + for resource_id, grants in resource_grants.items(): + if grants["is_public"]: + # Public = NULL + access_control_value = None + elif (not grants["read"]["group_ids"] and not grants["read"]["user_ids"] and + not grants["write"]["group_ids"] and not grants["write"]["user_ids"]): + # No grants = should not happen (would mean no entries), default to {} + access_control_value = json.dumps({}) + else: + # Custom permissions + access_control_value = json.dumps({ + "read": grants["read"], + "write": grants["write"], + }) + + try: + conn.execute( + sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'), + {"access_control": access_control_value, "id": resource_id} + ) + except Exception: + pass + + # Step 4: Set all resources WITHOUT entries to private + # For files: NULL means private (owner-only), so leave as NULL + # For other resources: {} means private, so update to {} + if resource_type != "file": + try: + conn.execute( + sa.text(f''' + UPDATE "{table_name}" + SET access_control = :private_value + WHERE id NOT IN ( + SELECT DISTINCT resource_id FROM access_grant WHERE resource_type = :resource_type + ) + AND access_control IS NULL + '''), + {"private_value": json.dumps({}), "resource_type": resource_type} + ) + except Exception: + pass + # For files, NULL stays NULL - no action needed + + # Step 5: Drop the access_grant table + op.drop_index("idx_access_grant_principal", table_name="access_grant") + op.drop_index("idx_access_grant_resource", table_name="access_grant") + op.drop_table("access_grant") diff --git a/backend/open_webui/models/access_grants.py b/backend/open_webui/models/access_grants.py new file mode 100644 index 000000000..aac475f3e --- /dev/null +++ b/backend/open_webui/models/access_grants.py @@ -0,0 +1,776 @@ +import logging +import time +import uuid +from typing import Optional + +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db_context + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, or_, and_ +from sqlalchemy.dialects.postgresql import JSONB + +log = logging.getLogger(__name__) + + +#################### +# AccessGrant DB Schema +#################### + + +class AccessGrant(Base): + __tablename__ = "access_grant" + + id = Column(Text, primary_key=True) + resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file" + resource_id = Column(Text, nullable=False) + principal_type = Column(Text, nullable=False) # "user" or "group" + principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public) + permission = Column(Text, nullable=False) # "read" or "write" + created_at = Column(BigInteger, nullable=False) + + __table_args__ = ( + UniqueConstraint( + "resource_type", + "resource_id", + "principal_type", + "principal_id", + "permission", + name="uq_access_grant_grant", + ), + ) + + +class AccessGrantModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + resource_type: str + resource_id: str + principal_type: str + principal_id: str + permission: str + created_at: int + + +class AccessGrantResponse(BaseModel): + """Slim grant model for API responses — resource context is implicit from the parent.""" + + id: str + principal_type: str + principal_id: str + permission: str + + @classmethod + def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse": + return cls( + id=grant.id, + principal_type=grant.principal_type, + principal_id=grant.principal_id, + permission=grant.permission, + ) + + +#################### +# Conversion utilities +#################### + + +def access_control_to_grants( + resource_type: str, + resource_id: str, + access_control: Optional[dict], +) -> list[dict]: + """ + Convert an old-style access_control JSON dict to a flat list of grant dicts. + + Semantics: + - None → public read (user:* read) — except files which are private + - {} → private/owner-only (no grants) + - {read: {group_ids, user_ids}, write: {group_ids, user_ids}} → specific grants + + Returns a list of dicts with keys: resource_type, resource_id, principal_type, principal_id, permission + """ + grants = [] + + if access_control is None: + # NULL → public read (user:* for read) + # Exception: files with NULL are private (owner-only), no grants needed + if resource_type != "file": + grants.append( + { + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "user", + "principal_id": "*", + "permission": "read", + } + ) + return grants + + # {} → private/owner-only, no grants + if not access_control: + return grants + + # Parse structured permissions + for permission in ["read", "write"]: + perm_data = access_control.get(permission, {}) + if not perm_data: + continue + + for group_id in perm_data.get("group_ids", []): + grants.append( + { + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "group", + "principal_id": group_id, + "permission": permission, + } + ) + + for user_id in perm_data.get("user_ids", []): + grants.append( + { + "resource_type": resource_type, + "resource_id": resource_id, + "principal_type": "user", + "principal_id": user_id, + "permission": permission, + } + ) + + return grants + + +def normalize_access_grants(access_grants: Optional[list]) -> list[dict]: + """ + Normalize direct access_grants payloads from API forms. + + Keeps only valid grants and removes duplicates by + (principal_type, principal_id, permission). + """ + if not access_grants: + return [] + + deduped = {} + for grant in access_grants: + if isinstance(grant, BaseModel): + grant = grant.model_dump() + if not isinstance(grant, dict): + continue + + principal_type = grant.get("principal_type") + principal_id = grant.get("principal_id") + permission = grant.get("permission") + + if principal_type not in ("user", "group"): + continue + if permission not in ("read", "write"): + continue + if not isinstance(principal_id, str) or not principal_id: + continue + + key = (principal_type, principal_id, permission) + deduped[key] = { + "id": grant.get("id") + if isinstance(grant.get("id"), str) and grant.get("id") + else str(uuid.uuid4()), + "principal_type": principal_type, + "principal_id": principal_id, + "permission": permission, + } + + return list(deduped.values()) + + +def has_public_read_access_grant(access_grants: Optional[list]) -> bool: + """ + Returns True when a direct grant list includes wildcard public-read. + """ + for grant in normalize_access_grants(access_grants): + if ( + grant["principal_type"] == "user" + and grant["principal_id"] == "*" + and grant["permission"] == "read" + ): + return True + return False + + +def grants_to_access_control(grants: list) -> Optional[dict]: + """ + Convert a list of grant objects (AccessGrantModel or AccessGrantResponse) + back to the old-style access_control JSON dict for backward compatibility. + + Semantics: + - [] (empty) → {} (private/owner-only) + - Contains user:*:read → None (public), but write grants are preserved + - Otherwise → {read: {group_ids, user_ids}, write: {group_ids, user_ids}} + + Note: "public" (user:*:read) still allows additional write permissions + to coexist. When the wildcard read is present the function returns None + for the legacy dict, so callers that need write info should inspect the + grants list directly. + """ + if not grants: + return {} # No grants = private/owner-only + + result = { + "read": {"group_ids": [], "user_ids": []}, + "write": {"group_ids": [], "user_ids": []}, + } + + is_public = False + for grant in grants: + if ( + grant.principal_type == "user" + and grant.principal_id == "*" + and grant.permission == "read" + ): + is_public = True + continue # Don't add wildcard to user_ids list + + if grant.permission not in ("read", "write"): + continue + + if grant.principal_type == "group": + if grant.principal_id not in result[grant.permission]["group_ids"]: + result[grant.permission]["group_ids"].append(grant.principal_id) + elif grant.principal_type == "user": + if grant.principal_id not in result[grant.permission]["user_ids"]: + result[grant.permission]["user_ids"].append(grant.principal_id) + + if is_public: + return None # Public read access + + return result + + +#################### +# Table Operations +#################### + + +class AccessGrantsTable: + def grant_access( + self, + resource_type: str, + resource_id: str, + principal_type: str, + principal_id: str, + permission: str, + db: Optional[Session] = None, + ) -> Optional[AccessGrantModel]: + """Add a single access grant. Idempotent (ignores duplicates).""" + with get_db_context(db) as db: + # Check for existing grant + existing = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + principal_type=principal_type, + principal_id=principal_id, + permission=permission, + ) + .first() + ) + if existing: + return AccessGrantModel.model_validate(existing) + + grant = AccessGrant( + id=str(uuid.uuid4()), + resource_type=resource_type, + resource_id=resource_id, + principal_type=principal_type, + principal_id=principal_id, + permission=permission, + created_at=int(time.time()), + ) + db.add(grant) + db.commit() + db.refresh(grant) + return AccessGrantModel.model_validate(grant) + + def revoke_access( + self, + resource_type: str, + resource_id: str, + principal_type: str, + principal_id: str, + permission: str, + db: Optional[Session] = None, + ) -> bool: + """Remove a single access grant.""" + with get_db_context(db) as db: + deleted = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + principal_type=principal_type, + principal_id=principal_id, + permission=permission, + ) + .delete() + ) + db.commit() + return deleted > 0 + + def revoke_all_access( + self, + resource_type: str, + resource_id: str, + db: Optional[Session] = None, + ) -> int: + """Remove all access grants for a resource.""" + with get_db_context(db) as db: + deleted = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + ) + .delete() + ) + db.commit() + return deleted + + def set_access_control( + self, + resource_type: str, + resource_id: str, + access_control: Optional[dict], + db: Optional[Session] = None, + ) -> list[AccessGrantModel]: + """ + Replace all grants for a resource from an access_control JSON dict. + This is the primary bridge for backward compat with the frontend. + """ + with get_db_context(db) as db: + # Delete all existing grants for this resource + db.query(AccessGrant).filter_by( + resource_type=resource_type, + resource_id=resource_id, + ).delete() + + # Convert JSON to grant dicts + grant_dicts = access_control_to_grants( + resource_type, resource_id, access_control + ) + + # Insert new grants + results = [] + for grant_dict in grant_dicts: + grant = AccessGrant( + id=str(uuid.uuid4()), + **grant_dict, + created_at=int(time.time()), + ) + db.add(grant) + results.append(grant) + + db.commit() + + return [AccessGrantModel.model_validate(g) for g in results] + + def set_access_grants( + self, + resource_type: str, + resource_id: str, + access_grants: Optional[list], + db: Optional[Session] = None, + ) -> list[AccessGrantModel]: + """ + Replace all grants for a resource from a direct access_grants list. + """ + with get_db_context(db) as db: + db.query(AccessGrant).filter_by( + resource_type=resource_type, + resource_id=resource_id, + ).delete() + + normalized_grants = normalize_access_grants(access_grants) + + results = [] + for grant_dict in normalized_grants: + grant = AccessGrant( + id=grant_dict["id"], + resource_type=resource_type, + resource_id=resource_id, + principal_type=grant_dict["principal_type"], + principal_id=grant_dict["principal_id"], + permission=grant_dict["permission"], + created_at=int(time.time()), + ) + db.add(grant) + results.append(grant) + + db.commit() + return [AccessGrantModel.model_validate(g) for g in results] + + def get_access_control( + self, + resource_type: str, + resource_id: str, + db: Optional[Session] = None, + ) -> Optional[dict]: + """ + Reconstruct the old-style access_control JSON dict from grants. + For backward compat with the frontend. + """ + with get_db_context(db) as db: + grants = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + ) + .all() + ) + grant_models = [AccessGrantModel.model_validate(g) for g in grants] + return grants_to_access_control(grant_models) + + def get_grants_by_resource( + self, + resource_type: str, + resource_id: str, + db: Optional[Session] = None, + ) -> list[AccessGrantModel]: + """Get all grants for a specific resource.""" + with get_db_context(db) as db: + grants = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + ) + .all() + ) + return [AccessGrantModel.model_validate(g) for g in grants] + + def has_access( + self, + user_id: str, + resource_type: str, + resource_id: str, + permission: str = "read", + user_group_ids: Optional[set[str]] = None, + db: Optional[Session] = None, + ) -> bool: + """ + Check if a user has the specified permission on a resource. + + Access is granted if any of the following is true: + - There's a grant for user:* (public) with the requested permission + - There's a grant for the specific user with the requested permission + - There's a grant for any of the user's groups with the requested permission + """ + with get_db_context(db) as db: + # Build conditions for matching grants + conditions = [ + # Public access + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == "*", + ), + # Direct user access + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ), + ] + + # Group access + if user_group_ids is None: + from open_webui.models.groups import Groups + + user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_group_ids = {group.id for group in user_groups} + + if user_group_ids: + conditions.append( + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(user_group_ids), + ) + ) + + exists = ( + db.query(AccessGrant) + .filter( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id == resource_id, + AccessGrant.permission == permission, + or_(*conditions), + ) + .first() + ) + return exists is not None + + def get_users_with_access( + self, + resource_type: str, + resource_id: str, + permission: str = "read", + db: Optional[Session] = None, + ) -> list: + """ + Get all users who have the specified permission on a resource. + Returns a list of UserModel instances. + """ + from open_webui.models.users import Users, UserModel + from open_webui.models.groups import Groups + + with get_db_context(db) as db: + grants = ( + db.query(AccessGrant) + .filter_by( + resource_type=resource_type, + resource_id=resource_id, + permission=permission, + ) + .all() + ) + + # Check for public access + for grant in grants: + if grant.principal_type == "user" and grant.principal_id == "*": + result = Users.get_users(filter={"roles": ["!pending"]}, db=db) + return result.get("users", []) + + user_ids_with_access = set() + + for grant in grants: + if grant.principal_type == "user": + user_ids_with_access.add(grant.principal_id) + elif grant.principal_type == "group": + group_user_ids = Groups.get_group_user_ids_by_id( + grant.principal_id, db=db + ) + if group_user_ids: + user_ids_with_access.update(group_user_ids) + + if not user_ids_with_access: + return [] + + return Users.get_users_by_user_ids(list(user_ids_with_access), db=db) + + def has_permission_filter( + self, + db, + query, + DocumentModel, + filter: dict, + resource_type: str, + permission: str = "read", + ): + """ + Apply access control filtering to a SQLAlchemy query by JOINing with access_grant. + + This replaces the old JSON-column-based filtering with a proper relational JOIN. + """ + group_ids = filter.get("group_ids", []) + user_id = filter.get("user_id") + + if permission == "read_only": + return self._has_read_only_permission_filter( + db, query, DocumentModel, filter, resource_type + ) + + # Build principal conditions + principal_conditions = [] + + if group_ids or user_id: + # Public access: user:* read + principal_conditions.append( + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == "*", + ) + ) + + if user_id: + # Owner always has access + principal_conditions.append(DocumentModel.user_id == user_id) + + # Direct user grant + principal_conditions.append( + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ) + ) + + if group_ids: + # Group grants + principal_conditions.append( + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(group_ids), + ) + ) + + if not principal_conditions: + return query + + # LEFT JOIN access_grant and filter + # We use a subquery approach to avoid duplicates from multiple matching grants + from sqlalchemy import exists as sa_exists, select + + grant_exists = ( + select(AccessGrant.id) + .where( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id == DocumentModel.id, + AccessGrant.permission == permission, + or_( + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == "*", + ), + *( + [ + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ) + ] + if user_id + else [] + ), + *( + [ + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(group_ids), + ) + ] + if group_ids + else [] + ), + ), + ) + .correlate(DocumentModel) + .exists() + ) + + # Owner OR has a matching grant + owner_or_grant = [grant_exists] + if user_id: + owner_or_grant.append(DocumentModel.user_id == user_id) + + query = query.filter(or_(*owner_or_grant)) + return query + + def _has_read_only_permission_filter( + self, + db, + query, + DocumentModel, + filter: dict, + resource_type: str, + ): + """ + Filter for items where user has read BUT NOT write access. + Public items are NOT considered read_only. + """ + group_ids = filter.get("group_ids", []) + user_id = filter.get("user_id") + + from sqlalchemy import exists as sa_exists, select + + # Has read grant (not public) + read_grant_exists = ( + select(AccessGrant.id) + .where( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id == DocumentModel.id, + AccessGrant.permission == "read", + or_( + *( + [ + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ) + ] + if user_id + else [] + ), + *( + [ + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(group_ids), + ) + ] + if group_ids + else [] + ), + ), + ) + .correlate(DocumentModel) + .exists() + ) + + # Does NOT have write grant + write_grant_exists = ( + select(AccessGrant.id) + .where( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id == DocumentModel.id, + AccessGrant.permission == "write", + or_( + *( + [ + and_( + AccessGrant.principal_type == "user", + AccessGrant.principal_id == user_id, + ) + ] + if user_id + else [] + ), + *( + [ + and_( + AccessGrant.principal_type == "group", + AccessGrant.principal_id.in_(group_ids), + ) + ] + if group_ids + else [] + ), + ), + ) + .correlate(DocumentModel) + .exists() + ) + + # Is NOT public + public_grant_exists = ( + select(AccessGrant.id) + .where( + AccessGrant.resource_type == resource_type, + AccessGrant.resource_id == DocumentModel.id, + AccessGrant.permission == "read", + AccessGrant.principal_type == "user", + AccessGrant.principal_id == "*", + ) + .correlate(DocumentModel) + .exists() + ) + + conditions = [read_grant_exists, ~write_grant_exists, ~public_grant_exists] + + # Not owner + if user_id: + conditions.append(DocumentModel.user_id != user_id) + + query = query.filter(and_(*conditions)) + return query + + +AccessGrants = AccessGrantsTable() diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 8e70918e1..3ff6fb755 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -7,8 +7,12 @@ from typing import Optional from sqlalchemy.orm import Session from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups +from open_webui.models.access_grants import ( + AccessGrantModel, + AccessGrants, +) -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.dialects.postgresql import JSONB @@ -47,7 +51,6 @@ class Channel(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) created_at = Column(BigInteger) @@ -76,7 +79,7 @@ class ChannelModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch (time_ns) @@ -237,7 +240,7 @@ class ChannelForm(BaseModel): is_private: Optional[bool] = None data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None group_ids: Optional[list[str]] = None user_ids: Optional[list[str]] = None @@ -252,6 +255,18 @@ class ChannelWebhookForm(BaseModel): class ChannelTable: + def _get_access_grants( + self, channel_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("channel", channel_id, db=db) + + def _to_channel_model( + self, channel: Channel, db: Optional[Session] = None + ) -> ChannelModel: + channel_data = ChannelModel.model_validate(channel).model_dump(exclude={"access_grants"}) + access_grants = self._get_access_grants(channel_data["id"], db=db) + channel_data["access_grants"] = access_grants + return ChannelModel.model_validate(channel_data) def _collect_unique_user_ids( self, @@ -316,16 +331,17 @@ class ChannelTable: with get_db_context(db) as db: channel = ChannelModel( **{ - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "type": form_data.type if form_data.type else None, "name": form_data.name.lower(), "id": str(uuid.uuid4()), "user_id": user_id, "created_at": int(time.time_ns()), "updated_at": int(time.time_ns()), + "access_grants": [], } ) - new_channel = Channel(**channel.model_dump()) + new_channel = Channel(**channel.model_dump(exclude={"access_grants"})) if form_data.type in ["group", "dm"]: users = self._collect_unique_user_ids( @@ -342,54 +358,25 @@ class ChannelTable: db.add_all(memberships) db.add(new_channel) db.commit() - return channel + AccessGrants.set_access_grants( + "channel", new_channel.id, form_data.access_grants, db=db + ) + return self._to_channel_model(new_channel, db=db) def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: channels = db.query(Channel).all() - return [ChannelModel.model_validate(channel) for channel in channels] + return [self._to_channel_model(channel, db=db) for channel in channels] def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - - dialect_name = db.bind.dialect.name - - # Public access - conditions = [] - if group_ids or user_id: - conditions.extend( - [ - Channel.access_control.is_(None), - cast(Channel.access_control, String) == "null", - ] - ) - - # User-level permission - if user_id: - conditions.append(Channel.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Channel.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Channel.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Channel, + filter=filter, + resource_type="channel", + permission=permission, + ) def get_channels_by_user_id( self, user_id: str, db: Optional[Session] = None @@ -428,7 +415,7 @@ class ChannelTable: standard_channels = query.all() all_channels = membership_channels + standard_channels - return [ChannelModel.model_validate(c) for c in all_channels] + return [self._to_channel_model(c, db=db) for c in all_channels] def get_dm_channel_by_user_ids( self, user_ids: list[str], db: Optional[Session] = None @@ -463,7 +450,7 @@ class ChannelTable: .first() ) - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None def add_members_to_channel( self, @@ -722,7 +709,7 @@ class ChannelTable: try: with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None except Exception: return None @@ -735,7 +722,7 @@ class ChannelTable: ) channel_ids = [cf.channel_id for cf in channel_files] channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all() - return [ChannelModel.model_validate(channel) for channel in channels] + return [self._to_channel_model(channel, db=db) for channel in channels] def get_channels_by_file_id_and_user_id( self, file_id: str, user_id: str, db: Optional[Session] = None @@ -783,7 +770,9 @@ class ChannelTable: .first() ) if membership: - allowed_channels.append(ChannelModel.model_validate(channel)) + allowed_channels.append( + self._to_channel_model(channel, db=db) + ) continue # --- Case B: standard channel => rely on ACL permissions --- @@ -798,7 +787,7 @@ class ChannelTable: allowed = query.first() if allowed: - allowed_channels.append(ChannelModel.model_validate(allowed)) + allowed_channels.append(self._to_channel_model(allowed, db=db)) return allowed_channels @@ -832,7 +821,7 @@ class ChannelTable: .first() ) if membership: - return ChannelModel.model_validate(channel) + return self._to_channel_model(channel, db=db) else: return None @@ -854,7 +843,7 @@ class ChannelTable: channel_allowed = query.first() return ( - ChannelModel.model_validate(channel_allowed) + self._to_channel_model(channel_allowed, db=db) if channel_allowed else None ) @@ -874,11 +863,14 @@ class ChannelTable: channel.data = form_data.data channel.meta = form_data.meta - channel.access_control = form_data.access_control + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "channel", id, form_data.access_grants, db=db + ) channel.updated_at = int(time.time_ns()) db.commit() - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None def add_file_to_channel_by_id( self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None @@ -947,6 +939,7 @@ class ChannelTable: def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: + AccessGrants.revoke_all_access("channel", id, db=db) db.query(Channel).filter(Channel.id == id).delete() db.commit() return True diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index c24b242bd..67f289160 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -26,8 +26,6 @@ class File(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) - created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -45,8 +43,6 @@ class FileModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None - created_at: Optional[int] # timestamp in epoch updated_at: Optional[int] # timestamp in epoch @@ -113,7 +109,6 @@ class FileForm(BaseModel): path: str data: dict = {} meta: dict = {} - access_control: Optional[dict] = None class FileUpdateForm(BaseModel): diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index eab817534..0859c053a 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -22,6 +22,7 @@ from sqlalchemy import ( ForeignKey, cast, or_, + select, ) @@ -99,6 +100,16 @@ class GroupResponse(GroupModel): member_count: Optional[int] = None +class GroupInfoResponse(BaseModel): + id: str + user_id: str + name: str + description: str + member_count: Optional[int] = None + created_at: int + updated_at: int + + class GroupForm(BaseModel): name: str description: str @@ -171,22 +182,22 @@ class GroupTable: if share_value: # Groups open to anyone: data is null, config.share is null, or share is true # Use case-insensitive string comparison to handle variations like "True", "TRUE" + # Handle potential JSON boolean to string casting issues by checking for both string 'true' and boolean equivalence if possible, anyone_can_share = or_( Group.data.is_(None), json_share_str.is_(None), json_share_lower == "true", + json_share_lower == "1", # Handle SQLite boolean true ) if member_id: # Also include member-only groups where user is a member - member_groups_subq = ( - db.query(GroupMember.group_id) - .filter(GroupMember.user_id == member_id) - .subquery() + member_groups_select = select(GroupMember.group_id).where( + GroupMember.user_id == member_id ) members_only_and_is_member = and_( json_share_lower == "members", - Group.id.in_(member_groups_subq), + Group.id.in_(member_groups_select), ) query = query.filter( or_(anyone_can_share, members_only_and_is_member) @@ -305,14 +316,14 @@ class GroupTable: def get_group_user_ids_by_id( self, id: str, db: Optional[Session] = None - ) -> Optional[list[str]]: + ) -> list[str]: with get_db_context(db) as db: members = ( db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() ) if not members: - return None + return [] return [m[0] for m in members] diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index 81aa4099d..817cab5ca 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -15,9 +15,10 @@ from open_webui.models.files import ( ) from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, Users, UserResponse +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import ( BigInteger, Column, @@ -29,9 +30,6 @@ from sqlalchemy import ( or_, ) -from open_webui.utils.access_control import has_access -from open_webui.utils.db.access_control import has_permission - log = logging.getLogger(__name__) @@ -50,22 +48,6 @@ class Knowledge(Base): description = Column(Text) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -82,7 +64,7 @@ class KnowledgeModel(BaseModel): meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -139,7 +121,7 @@ class KnowledgeUserResponse(KnowledgeUserModel): class KnowledgeForm(BaseModel): name: str description: str - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class FileUserResponse(FileModelResponse): @@ -157,27 +139,47 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: + def _get_access_grants( + self, knowledge_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db) + + def _to_knowledge_model( + self, knowledge: Knowledge, db: Optional[Session] = None + ) -> KnowledgeModel: + knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump( + exclude={"access_grants"} + ) + knowledge_data["access_grants"] = self._get_access_grants( + knowledge_data["id"], db=db + ) + return KnowledgeModel.model_validate(knowledge_data) + def insert_new_knowledge( self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: with get_db_context(db) as db: knowledge = KnowledgeModel( **{ - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "id": str(uuid.uuid4()), "user_id": user_id, "created_at": int(time.time()), "updated_at": int(time.time()), + "access_grants": [], } ) try: - result = Knowledge(**knowledge.model_dump()) + result = Knowledge(**knowledge.model_dump(exclude={"access_grants"})) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "knowledge", result.id, form_data.access_grants, db=db + ) if result: - return KnowledgeModel.model_validate(result) + return self._to_knowledge_model(result, db=db) else: return None except Exception: @@ -201,7 +203,7 @@ class KnowledgeTable: knowledge_bases.append( KnowledgeUserModel.model_validate( { - **KnowledgeModel.model_validate(knowledge).model_dump(), + **self._to_knowledge_model(knowledge, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -241,7 +243,14 @@ class KnowledgeTable: elif view_option == "shared": query = query.filter(Knowledge.user_id != user_id) - query = has_permission(db, Knowledge, query, filter) + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Knowledge, + filter=filter, + resource_type="knowledge", + permission="read", + ) query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) @@ -258,8 +267,8 @@ class KnowledgeTable: knowledge_bases.append( KnowledgeUserModel.model_validate( { - **KnowledgeModel.model_validate( - knowledge_base + **self._to_knowledge_model( + knowledge_base, db=db ).model_dump(), "user": ( UserModel.model_validate(user).model_dump() @@ -294,7 +303,14 @@ class KnowledgeTable: # Apply access-control directly to the joined query # This makes the database handle filtering, even with 10k+ KBs - query = has_permission(db, Knowledge, query, filter) + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Knowledge, + filter=filter, + resource_type="knowledge", + permission="read", + ) # Apply filename search if filter: @@ -327,8 +343,8 @@ class KnowledgeTable: if user else None ), - collection=KnowledgeModel.model_validate( - knowledge + collection=self._to_knowledge_model( + knowledge, db=db ).model_dump(), ) ) @@ -350,7 +366,14 @@ class KnowledgeTable: user_group_ids = { group.id for group in Groups.get_groups_by_member_id(user_id, db=db) } - return has_access(user_id, permission, knowledge.access_control, user_group_ids) + return AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) def get_knowledge_bases_by_user_id( self, user_id: str, permission: str = "write", db: Optional[Session] = None @@ -363,8 +386,13 @@ class KnowledgeTable: knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id - or has_access( - user_id, permission, knowledge_base.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, ) ] @@ -374,7 +402,9 @@ class KnowledgeTable: try: with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() - return KnowledgeModel.model_validate(knowledge) if knowledge else None + return ( + self._to_knowledge_model(knowledge, db=db) if knowledge else None + ) except Exception: return None @@ -391,7 +421,14 @@ class KnowledgeTable: user_group_ids = { group.id for group in Groups.get_groups_by_member_id(user_id, db=db) } - if has_access(user_id, "write", knowledge.access_control, user_group_ids): + if AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + user_group_ids=user_group_ids, + db=db, + ): return knowledge return None @@ -406,9 +443,7 @@ class KnowledgeTable: .filter(KnowledgeFile.file_id == file_id) .all() ) - return [ - KnowledgeModel.model_validate(knowledge) for knowledge in knowledges - ] + return [self._to_knowledge_model(knowledge, db=db) for knowledge in knowledges] except Exception: return [] @@ -591,11 +626,15 @@ class KnowledgeTable: knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "updated_at": int(time.time()), } ) db.commit() + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "knowledge", id, form_data.access_grants, db=db + ) return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) @@ -622,6 +661,7 @@ class KnowledgeTable: def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("knowledge", id, db=db) db.query(Knowledge).filter_by(id=id).delete() db.commit() return True @@ -631,6 +671,9 @@ class KnowledgeTable: def delete_all_knowledge(self, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: + knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()] + for knowledge_id in knowledge_ids: + AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db) db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 5a59861dd..d523ae0fc 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -7,18 +7,16 @@ from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, Users, UserResponse +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import String, cast, or_, and_, func from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy import BigInteger, Column, Text, JSON, Boolean - - -from open_webui.utils.access_control import has_access +from sqlalchemy import BigInteger, Column, Text, Boolean log = logging.getLogger(__name__) @@ -80,23 +78,6 @@ class Model(Base): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } - is_active = Column(Boolean, default=True) updated_at = Column(BigInteger) @@ -112,7 +93,7 @@ class ModelModel(BaseModel): params: ModelParams meta: ModelMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) is_active: bool updated_at: int # timestamp in epoch @@ -154,31 +135,45 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None is_active: bool = True class ModelsTable: + def _get_access_grants( + self, model_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("model", model_id, db=db) + + def _to_model_model(self, model: Model, db: Optional[Session] = None) -> ModelModel: + model_data = ModelModel.model_validate(model).model_dump( + exclude={"access_grants"} + ) + model_data["access_grants"] = self._get_access_grants(model_data["id"], db=db) + return ModelModel.model_validate(model_data) + def insert_new_model( self, form_data: ModelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ModelModel]: - model = ModelModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) try: with get_db_context(db) as db: - result = Model(**model.model_dump()) + result = Model( + **{ + **form_data.model_dump(exclude={"access_grants"}), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "model", result.id, form_data.access_grants, db=db + ) if result: - return ModelModel.model_validate(result) + return self._to_model_model(result, db=db) else: return None except Exception as e: @@ -187,7 +182,7 @@ class ModelsTable: def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + return [self._to_model_model(model, db=db) for model in db.query(Model).all()] def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: with get_db_context(db) as db: @@ -204,7 +199,7 @@ class ModelsTable: models.append( ModelUserResponse.model_validate( { - **ModelModel.model_validate(model).model_dump(), + **self._to_model_model(model, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -214,7 +209,7 @@ class ModelsTable: def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: return [ - ModelModel.model_validate(model) + self._to_model_model(model, db=db) for model in db.query(Model).filter(Model.base_model_id == None).all() ] @@ -229,50 +224,25 @@ class ModelsTable: model for model in models if model.user_id == user_id - or has_access(user_id, permission, model.access_control, user_group_ids) + or AccessGrants.has_access( + user_id=user_id, + resource_type="model", + resource_id=model.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - - dialect_name = db.bind.dialect.name - - # Public access - conditions = [] - if group_ids or user_id: - conditions.extend( - [ - Model.access_control.is_(None), - cast(Model.access_control, String) == "null", - ] - ) - - # User-level permission - if user_id: - conditions.append(Model.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Model.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Model.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Model, + filter=filter, + resource_type="model", + permission=permission, + ) def search_models( self, @@ -358,7 +328,7 @@ class ModelsTable: for model, user in items: models.append( ModelUserResponse( - **ModelModel.model_validate(model).model_dump(), + **self._to_model_model(model, db=db).model_dump(), user=( UserResponse(**UserModel.model_validate(user).model_dump()) if user @@ -375,7 +345,7 @@ class ModelsTable: try: with get_db_context(db) as db: model = db.get(Model, id) - return ModelModel.model_validate(model) + return self._to_model_model(model, db=db) if model else None except Exception: return None @@ -385,7 +355,7 @@ class ModelsTable: try: with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() - return [ModelModel.model_validate(model) for model in models] + return [self._to_model_model(model, db=db) for model in models] except Exception: return [] @@ -403,7 +373,7 @@ class ModelsTable: db.commit() db.refresh(model) - return ModelModel.model_validate(model) + return self._to_model_model(model, db=db) except Exception: return None @@ -413,14 +383,16 @@ class ModelsTable: try: with get_db_context(db) as db: # update only the fields that are present in the model - data = model.model_dump(exclude={"id"}) + data = model.model_dump(exclude={"id", "access_grants"}) result = db.query(Model).filter_by(id=id).update(data) db.commit() + if model.access_grants is not None: + AccessGrants.set_access_grants( + "model", id, model.access_grants, db=db + ) - model = db.get(Model, id) - db.refresh(model) - return ModelModel.model_validate(model) + return self.get_model_by_id(id, db=db) except Exception as e: log.exception(f"Failed to update the model by id {id}: {e}") return None @@ -428,6 +400,7 @@ class ModelsTable: def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("model", id, db=db) db.query(Model).filter_by(id=id).delete() db.commit() @@ -438,6 +411,9 @@ class ModelsTable: def delete_all_models(self, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + model_ids = [row[0] for row in db.query(Model.id).all()] + for model_id in model_ids: + AccessGrants.revoke_all_access("model", model_id, db=db) db.query(Model).delete() db.commit() @@ -462,7 +438,7 @@ class ModelsTable: if model.id in existing_ids: db.query(Model).filter_by(id=model.id).update( { - **model.model_dump(), + **model.model_dump(exclude={"access_grants"}), "user_id": user_id, "updated_at": int(time.time()), } @@ -470,22 +446,27 @@ class ModelsTable: else: new_model = Model( **{ - **model.model_dump(), + **model.model_dump(exclude={"access_grants"}), "user_id": user_id, "updated_at": int(time.time()), } ) db.add(new_model) + AccessGrants.set_access_grants( + "model", model.id, model.access_grants, db=db + ) # Remove models that are no longer present for model in existing_models: if model.id not in new_model_ids: + AccessGrants.revoke_all_access("model", model.id, db=db) db.delete(model) db.commit() return [ - ModelModel.model_validate(model) for model in db.query(Model).all() + self._to_model_model(model, db=db) + for model in db.query(Model).all() ] except Exception as e: log.exception(f"Error syncing models for user {user_id}: {e}") diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index bd2353078..d17c749d1 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -7,17 +7,13 @@ from functools import lru_cache from sqlalchemy.orm import Session from open_webui.internal.db import Base, get_db, get_db_context from open_webui.models.groups import Groups -from open_webui.utils.access_control import has_access from open_webui.models.users import User, UserModel, Users, UserResponse +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON -from sqlalchemy.dialects.postgresql import JSONB - - -from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func -from sqlalchemy.sql import exists +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import BigInteger, Column, Text, JSON +from sqlalchemy import or_, func, cast #################### # Note DB Schema @@ -34,8 +30,6 @@ class Note(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) - created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -50,7 +44,7 @@ class NoteModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -65,14 +59,14 @@ class NoteForm(BaseModel): title: str data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class NoteUpdateForm(BaseModel): title: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class NoteUserResponse(NoteModel): @@ -94,122 +88,25 @@ class NoteListResponse(BaseModel): class NoteTable: + def _get_access_grants( + self, note_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("note", note_id, db=db) + + def _to_note_model(self, note: Note, db: Optional[Session] = None) -> NoteModel: + note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"}) + note_data["access_grants"] = self._get_access_grants(note_data["id"], db=db) + return NoteModel.model_validate(note_data) + def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - dialect_name = db.bind.dialect.name - - conditions = [] - - # Handle read_only permission separately - if permission == "read_only": - # For read_only, we want items where: - # 1. User has explicit read permission (via groups or user-level) - # 2. BUT does NOT have write permission - # 3. Public items are NOT considered read_only - - read_conditions = [] - - # Group-level read permission - if group_ids: - group_read_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_read_conditions.append( - Note.access_control["read"]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_read_conditions.append( - cast( - Note.access_control["read"]["group_ids"], - JSONB, - ).contains([gid]) - ) - - if group_read_conditions: - read_conditions.append(or_(*group_read_conditions)) - - # Combine read conditions - if read_conditions: - has_read = or_(*read_conditions) - else: - # If no read conditions, return empty result - return query.filter(False) - - # Now exclude items where user has write permission - write_exclusions = [] - - # Exclude items owned by user (they have implicit write) - if user_id: - write_exclusions.append(Note.user_id != user_id) - - # Exclude items where user has explicit write permission via groups - if group_ids: - group_write_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_write_conditions.append( - Note.access_control["write"]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_write_conditions.append( - cast( - Note.access_control["write"]["group_ids"], - JSONB, - ).contains([gid]) - ) - - if group_write_conditions: - # User should NOT have write permission - write_exclusions.append(~or_(*group_write_conditions)) - - # Exclude public items (items without access_control) - write_exclusions.append(Note.access_control.isnot(None)) - write_exclusions.append(cast(Note.access_control, String) != "null") - - # Combine: has read AND does not have write AND not public - if write_exclusions: - query = query.filter(and_(has_read, *write_exclusions)) - else: - query = query.filter(has_read) - - return query - - # Original logic for other permissions (read, write, etc.) - # Public access conditions - if group_ids or user_id: - conditions.extend( - [ - Note.access_control.is_(None), - cast(Note.access_control, String) == "null", - ] - ) - - # User-level permission (owner has all permissions) - if user_id: - conditions.append(Note.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Note.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Note.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Note, + filter=filter, + resource_type="note", + permission=permission, + ) def insert_new_note( self, user_id: str, form_data: NoteForm, db: Optional[Session] = None @@ -219,17 +116,21 @@ class NoteTable: **{ "id": str(uuid.uuid4()), "user_id": user_id, - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "created_at": int(time.time_ns()), "updated_at": int(time.time_ns()), + "access_grants": [], } ) - new_note = Note(**note.model_dump()) + new_note = Note(**note.model_dump(exclude={"access_grants"})) db.add(new_note) db.commit() - return note + AccessGrants.set_access_grants( + "note", note.id, form_data.access_grants, db=db + ) + return self._to_note_model(new_note, db=db) def get_notes( self, skip: int = 0, limit: int = 50, db: Optional[Session] = None @@ -241,7 +142,7 @@ class NoteTable: if limit is not None: query = query.limit(limit) notes = query.all() - return [NoteModel.model_validate(note) for note in notes] + return [self._to_note_model(note, db=db) for note in notes] def search_notes( self, @@ -330,7 +231,7 @@ class NoteTable: for note, user in items: notes.append( NoteUserResponse( - **NoteModel.model_validate(note).model_dump(), + **self._to_note_model(note, db=db).model_dump(), user=( UserResponse(**UserModel.model_validate(user).model_dump()) if user @@ -365,14 +266,14 @@ class NoteTable: query = query.limit(limit) notes = query.all() - return [NoteModel.model_validate(note) for note in notes] + return [self._to_note_model(note, db=db) for note in notes] def get_note_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[NoteModel]: with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() - return NoteModel.model_validate(note) if note else None + return self._to_note_model(note, db=db) if note else None def update_note_by_id( self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None @@ -391,17 +292,20 @@ class NoteTable: if "meta" in form_data: note.meta = {**note.meta, **form_data["meta"]} - if "access_control" in form_data: - note.access_control = form_data["access_control"] + if "access_grants" in form_data: + AccessGrants.set_access_grants( + "note", id, form_data["access_grants"], db=db + ) note.updated_at = int(time.time_ns()) db.commit() - return NoteModel.model_validate(note) if note else None + return self._to_note_model(note, db=db) if note else None def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("note", id, db=db) db.query(Note).filter(Note.id == id).delete() db.commit() return True diff --git a/backend/open_webui/models/prompt_history.py b/backend/open_webui/models/prompt_history.py index ea7f566fb..0f5e7cea8 100644 --- a/backend/open_webui/models/prompt_history.py +++ b/backend/open_webui/models/prompt_history.py @@ -45,6 +45,7 @@ class PromptHistoryModel(BaseModel): class PromptHistoryResponse(PromptHistoryModel): """Response model with user info.""" + user: Optional[UserResponse] = None @@ -91,16 +92,20 @@ class PromptHistoryTable: .limit(limit) .all() ) - + # Get user info for each entry user_ids = list(set(e.user_id for e in entries)) users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} - + return [ PromptHistoryResponse( **PromptHistoryModel.model_validate(entry).model_dump(), - user=users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None, + user=( + users_dict.get(entry.user_id).model_dump() + if users_dict.get(entry.user_id) + else None + ), ) for entry in entries ] @@ -112,7 +117,9 @@ class PromptHistoryTable: ) -> Optional[PromptHistoryModel]: """Get a specific history entry by ID.""" with get_db_context(db) as db: - entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first() + entry = ( + db.query(PromptHistory).filter(PromptHistory.id == history_id).first() + ) if entry: return PromptHistoryModel.model_validate(entry) return None @@ -155,27 +162,31 @@ class PromptHistoryTable: ) -> Optional[dict]: """Compute diff between two history entries.""" with get_db_context(db) as db: - from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first() + from_entry = ( + db.query(PromptHistory).filter(PromptHistory.id == from_id).first() + ) to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first() - + if not from_entry or not to_entry: return None - + from_snapshot = from_entry.snapshot to_snapshot = to_entry.snapshot - + # Compute diff for content field from_content = from_snapshot.get("content", "") to_content = to_snapshot.get("content", "") - - diff_lines = list(difflib.unified_diff( - from_content.splitlines(keepends=True), - to_content.splitlines(keepends=True), - fromfile=f"v{from_id[:8]}", - tofile=f"v{to_id[:8]}", - lineterm="", - )) - + + diff_lines = list( + difflib.unified_diff( + from_content.splitlines(keepends=True), + to_content.splitlines(keepends=True), + fromfile=f"v{from_id[:8]}", + tofile=f"v{to_id[:8]}", + lineterm="", + ) + ) + return { "from_id": from_id, "to_id": to_id, @@ -183,7 +194,6 @@ class PromptHistoryTable: "to_snapshot": to_snapshot, "content_diff": diff_lines, "name_changed": from_snapshot.get("name") != to_snapshot.get("name"), - "access_control_changed": from_snapshot.get("access_control") != to_snapshot.get("access_control"), } def delete_history_by_prompt_id( @@ -193,7 +203,9 @@ class PromptHistoryTable: ) -> bool: """Delete all history entries for a prompt.""" with get_db_context(db) as db: - db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete() + db.query(PromptHistory).filter( + PromptHistory.prompt_id == prompt_id + ).delete() db.commit() return True diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 4a85ba902..544aea767 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -7,15 +7,13 @@ from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse from open_webui.models.prompt_history import PromptHistories +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, func, cast -from open_webui.utils.access_control import has_access - - #################### # Prompts DB Schema #################### @@ -37,23 +35,6 @@ class Prompt(Base): created_at = Column(BigInteger, nullable=True) updated_at = Column(BigInteger, nullable=True) - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } - class PromptModel(BaseModel): id: Optional[str] = None @@ -68,7 +49,7 @@ class PromptModel(BaseModel): version_id: Optional[str] = None created_at: Optional[int] = None updated_at: Optional[int] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) model_config = ConfigDict(from_attributes=True) @@ -104,13 +85,27 @@ class PromptForm(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None tags: Optional[list[str]] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None version_id: Optional[str] = None # Active version commit_message: Optional[str] = None # For history tracking is_production: Optional[bool] = True # Whether to set new version as production class PromptsTable: + def _get_access_grants( + self, prompt_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db) + + def _to_prompt_model( + self, prompt: Prompt, db: Optional[Session] = None + ) -> PromptModel: + prompt_data = PromptModel.model_validate(prompt).model_dump( + exclude={"access_grants"} + ) + prompt_data["access_grants"] = self._get_access_grants(prompt_data["id"], db=db) + return PromptModel.model_validate(prompt_data) + def insert_new_prompt( self, user_id: str, form_data: PromptForm, db: Optional[Session] = None ) -> Optional[PromptModel]: @@ -126,7 +121,7 @@ class PromptsTable: data=form_data.data or {}, meta=form_data.meta or {}, tags=form_data.tags or [], - access_control=form_data.access_control, + access_grants=[], is_active=True, created_at=now, updated_at=now, @@ -134,12 +129,16 @@ class PromptsTable: try: with get_db_context(db) as db: - result = Prompt(**prompt.model_dump()) + result = Prompt(**prompt.model_dump(exclude={"access_grants"})) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "prompt", prompt_id, form_data.access_grants, db=db + ) if result: + current_access_grants = self._get_access_grants(prompt_id, db=db) snapshot = { "name": form_data.name, "content": form_data.content, @@ -147,7 +146,7 @@ class PromptsTable: "data": form_data.data or {}, "meta": form_data.meta or {}, "tags": form_data.tags or [], - "access_control": form_data.access_control, + "access_grants": [grant.model_dump() for grant in current_access_grants], } history_entry = PromptHistories.create_history_entry( @@ -165,7 +164,7 @@ class PromptsTable: db.commit() db.refresh(result) - return PromptModel.model_validate(result) + return self._to_prompt_model(result, db=db) else: return None except Exception: @@ -179,7 +178,7 @@ class PromptsTable: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(id=prompt_id).first() if prompt: - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) return None except Exception: return None @@ -191,7 +190,7 @@ class PromptsTable: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() if prompt: - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) return None except Exception: return None @@ -216,7 +215,7 @@ class PromptsTable: prompts.append( PromptUserResponse.model_validate( { - **PromptModel.model_validate(prompt).model_dump(), + **self._to_prompt_model(prompt, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -236,7 +235,14 @@ class PromptsTable: prompt for prompt in prompts if prompt.user_id == user_id - or has_access(user_id, permission, prompt.access_control, user_group_ids) + or AccessGrants.has_access( + user_id=user_id, + resource_type="prompt", + resource_id=prompt.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] def search_prompts( @@ -273,17 +279,15 @@ class PromptsTable: elif view_option == "shared": query = query.filter(Prompt.user_id != user_id) - # Apply access control filtering - group_ids = filter.get("group_ids", []) - filter_user_id = filter.get("user_id") - - if filter_user_id: - # User must have access: owner OR public OR explicit access - access_conditions = [ - Prompt.user_id == filter_user_id, # Owner - Prompt.access_control == None, # Public - ] - query = query.filter(or_(*access_conditions)) + # Apply access grant filtering + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Prompt, + filter=filter, + resource_type="prompt", + permission="read", + ) tag = filter.get("tag") if tag: @@ -329,7 +333,7 @@ class PromptsTable: for prompt, user in items: prompts.append( PromptUserResponse( - **PromptModel.model_validate(prompt).model_dump(), + **self._to_prompt_model(prompt, db=db).model_dump(), user=( UserResponse(**UserModel.model_validate(user).model_dump()) if user @@ -358,12 +362,13 @@ class PromptsTable: prompt.id, db=db ) parent_id = latest_history.id if latest_history else None + current_access_grants = self._get_access_grants(prompt.id, db=db) # Check if content changed to decide on history creation content_changed = ( prompt.name != form_data.name or prompt.content != form_data.content - or prompt.access_control != form_data.access_control + or form_data.access_grants is not None ) # Update prompt fields @@ -371,8 +376,12 @@ class PromptsTable: prompt.content = form_data.content prompt.data = form_data.data or prompt.data prompt.meta = form_data.meta or prompt.meta - prompt.access_control = form_data.access_control prompt.updated_at = int(time.time()) + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "prompt", prompt.id, form_data.access_grants, db=db + ) + current_access_grants = self._get_access_grants(prompt.id, db=db) db.commit() @@ -384,7 +393,9 @@ class PromptsTable: "command": command, "data": form_data.data or {}, "meta": form_data.meta or {}, - "access_control": form_data.access_control, + "access_grants": [ + grant.model_dump() for grant in current_access_grants + ], } history_entry = PromptHistories.create_history_entry( @@ -401,7 +412,7 @@ class PromptsTable: prompt.version_id = history_entry.id db.commit() - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) except Exception: return None @@ -422,13 +433,14 @@ class PromptsTable: prompt.id, db=db ) parent_id = latest_history.id if latest_history else None + current_access_grants = self._get_access_grants(prompt.id, db=db) # Check if content changed to decide on history creation content_changed = ( prompt.name != form_data.name or prompt.command != form_data.command or prompt.content != form_data.content - or prompt.access_control != form_data.access_control + or form_data.access_grants is not None or (form_data.tags is not None and prompt.tags != form_data.tags) ) @@ -438,10 +450,15 @@ class PromptsTable: prompt.content = form_data.content prompt.data = form_data.data or prompt.data prompt.meta = form_data.meta or prompt.meta - prompt.access_control = form_data.access_control if form_data.tags is not None: prompt.tags = form_data.tags + + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "prompt", prompt.id, form_data.access_grants, db=db + ) + current_access_grants = self._get_access_grants(prompt.id, db=db) prompt.updated_at = int(time.time()) @@ -456,7 +473,9 @@ class PromptsTable: "data": form_data.data or {}, "meta": form_data.meta or {}, "tags": prompt.tags or [], - "access_control": form_data.access_control, + "access_grants": [ + grant.model_dump() for grant in current_access_grants + ], } history_entry = PromptHistories.create_history_entry( @@ -473,7 +492,7 @@ class PromptsTable: prompt.version_id = history_entry.id db.commit() - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) except Exception: return None @@ -501,7 +520,7 @@ class PromptsTable: prompt.updated_at = int(time.time()) db.commit() - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) except Exception: return None @@ -533,13 +552,13 @@ class PromptsTable: prompt.data = snapshot.get("data", prompt.data) prompt.meta = snapshot.get("meta", prompt.meta) prompt.tags = snapshot.get("tags", prompt.tags) - # Note: command and access_control are not restored from snapshot + # Note: command and access_grants are not restored from snapshot prompt.version_id = version_id prompt.updated_at = int(time.time()) db.commit() - return PromptModel.model_validate(prompt) + return self._to_prompt_model(prompt, db=db) except Exception: return None @@ -552,6 +571,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) prompt.is_active = False prompt.updated_at = int(time.time()) @@ -568,6 +588,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(id=prompt_id).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) prompt.is_active = False prompt.updated_at = int(time.time()) @@ -586,6 +607,7 @@ class PromptsTable: prompt = db.query(Prompt).filter_by(command=command).first() if prompt: PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) # Delete prompt db.query(Prompt).filter_by(command=command).delete() diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index cd7d0bd1a..da439161e 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -6,11 +6,10 @@ from sqlalchemy.orm import Session from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserResponse from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, JSON - -from open_webui.utils.access_control import has_access +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) @@ -31,23 +30,6 @@ class Tool(Base): meta = Column(JSONField) valves = Column(JSONField) - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } - updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -64,7 +46,7 @@ class ToolModel(BaseModel): content: str specs: list[dict] meta: ToolMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -86,7 +68,7 @@ class ToolResponse(BaseModel): user_id: str name: str meta: ToolMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -106,7 +88,7 @@ class ToolForm(BaseModel): name: str content: str meta: ToolMeta - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class ToolValves(BaseModel): @@ -114,6 +96,16 @@ class ToolValves(BaseModel): class ToolsTable: + def _get_access_grants( + self, tool_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("tool", tool_id, db=db) + + def _to_tool_model(self, tool: Tool, db: Optional[Session] = None) -> ToolModel: + tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"}) + tool_data["access_grants"] = self._get_access_grants(tool_data["id"], db=db) + return ToolModel.model_validate(tool_data) + def insert_new_tool( self, user_id: str, @@ -122,23 +114,24 @@ class ToolsTable: db: Optional[Session] = None, ) -> Optional[ToolModel]: with get_db_context(db) as db: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) + result = Tool( + **{ + **form_data.model_dump(exclude={"access_grants"}), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "tool", result.id, form_data.access_grants, db=db + ) if result: - return ToolModel.model_validate(result) + return self._to_tool_model(result, db=db) else: return None except Exception as e: @@ -151,7 +144,7 @@ class ToolsTable: try: with get_db_context(db) as db: tool = db.get(Tool, id) - return ToolModel.model_validate(tool) + return self._to_tool_model(tool, db=db) if tool else None except Exception: return None @@ -170,7 +163,7 @@ class ToolsTable: tools.append( ToolUserModel.model_validate( { - **ToolModel.model_validate(tool).model_dump(), + **self._to_tool_model(tool, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -189,7 +182,14 @@ class ToolsTable: tool for tool in tools if tool.user_id == user_id - or has_access(user_id, permission, tool.access_control, user_group_ids) + or AccessGrants.has_access( + user_id=user_id, + resource_type="tool", + resource_id=tool.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] def get_tool_valves_by_id( @@ -266,20 +266,24 @@ class ToolsTable: ) -> Optional[ToolModel]: try: with get_db_context(db) as db: + access_grants = updated.pop("access_grants", None) db.query(Tool).filter_by(id=id).update( {**updated, "updated_at": int(time.time())} ) db.commit() + if access_grants is not None: + AccessGrants.set_access_grants("tool", id, access_grants, db=db) tool = db.query(Tool).get(id) db.refresh(tool) - return ToolModel.model_validate(tool) + return self._to_tool_model(tool, db=db) except Exception: return None def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("tool", id, db=db) db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 56315c73f..e2e1b8577 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -29,9 +29,9 @@ from open_webui.models.knowledge import Knowledges from open_webui.models.chats import Chats from open_webui.models.notes import Notes +from open_webui.models.access_grants import AccessGrants from open_webui.retrieval.vector.main import GetResult -from open_webui.utils.access_control import has_access from open_webui.utils.headers import include_user_info_headers from open_webui.utils.misc import get_message_list @@ -999,7 +999,12 @@ async def get_sources_from_items( if note and ( user.role == "admin" or note.user_id == user.id - or has_access(user.id, "read", note.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): # User has access to the note query_result = { @@ -1091,7 +1096,12 @@ async def get_sources_from_items( if knowledge_base and ( user.role == "admin" or knowledge_base.user_id == user.id - or has_access(user.id, "read", knowledge_base.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + ) ): if ( item.get("context") == "full" @@ -1100,7 +1110,12 @@ async def get_sources_from_items( if knowledge_base and ( user.role == "admin" or knowledge_base.user_id == user.id - or has_access(user.id, "read", knowledge_base.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + ) ): files = Knowledges.get_files_by_id(knowledge_base.id) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 2c94ddd91..4713e369a 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -36,6 +36,7 @@ from open_webui.models.channels import ( ChannelWebhookModel, ChannelWebhookForm, ) +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.models.messages import ( Messages, MessageModel, @@ -60,12 +61,7 @@ from open_webui.utils.chat import generate_chat_completion from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import ( - has_access, - get_users_with_access, - get_permitted_group_and_user_ids, - has_permission, -) +from open_webui.utils.access_control import has_permission from open_webui.utils.webhook import post_webhook from open_webui.utils.channels import extract_mentions, replace_mentions from open_webui.internal.db import get_session @@ -76,6 +72,66 @@ log = logging.getLogger(__name__) router = APIRouter() +def channel_has_access( + user_id: str, + channel: ChannelModel, + permission: str = "read", + strict: bool = True, + db: Optional[Session] = None, +) -> bool: + if AccessGrants.has_access( + user_id=user_id, + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ): + return True + + if ( + not strict + and permission == "write" + and has_public_read_access_grant(channel.access_grants) + ): + return True + + return False + + +def get_channel_users_with_access( + channel: ChannelModel, permission: str = "read", db: Optional[Session] = None +): + return AccessGrants.get_users_with_access( + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ) + + +def get_channel_permitted_group_and_user_ids( + channel: ChannelModel, permission: str = "read" +) -> Optional[dict[str, list[str]]]: + if permission == "read" and has_public_read_access_grant(channel.access_grants): + return None + + user_ids = [] + group_ids = [] + + for grant in channel.access_grants: + if grant.permission != permission: + continue + if grant.principal_type == "group": + group_ids.append(grant.principal_id) + elif grant.principal_type == "user" and grant.principal_id != "*": + user_ids.append(grant.principal_id) + + return { + "user_ids": list(dict.fromkeys(user_ids)), + "group_ids": list(dict.fromkeys(group_ids)), + } + + ############################ # Channels Enabled Dependency ############################ @@ -418,22 +474,22 @@ async def get_channel_by_id( } ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - write_access = has_access( + write_access = channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) - user_count = len(get_users_with_access("read", channel.access_control)) + user_count = len(get_channel_users_with_access(channel, "read", db=db)) channel_member = Channels.get_member_by_channel_and_user_id( channel.id, user.id, db=db @@ -527,8 +583,8 @@ async def get_channel_members_by_id( filter["channel_id"] = channel.id else: filter["roles"] = ["!pending"] - permitted_ids = get_permitted_group_and_user_ids( - "read", channel.access_control + permitted_ids = get_channel_permitted_group_and_user_ids( + channel, permission="read" ) if permitted_ids: filter["user_ids"] = permitted_ids.get("user_ids") @@ -811,8 +867,8 @@ async def get_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -888,8 +944,8 @@ async def get_pinned_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -946,7 +1002,7 @@ async def get_pinned_channel_messages( async def send_notification( name, webui_url, channel, message, active_user_ids, db=None ): - users = get_users_with_access("read", channel.access_control) + users = get_channel_users_with_access(channel, "read", db=db) for user in users: if (user.id not in active_user_ids) and Channels.is_user_channel_member( @@ -1173,10 +1229,10 @@ async def new_message_handler( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1318,8 +1374,8 @@ async def get_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1372,8 +1428,8 @@ async def get_channel_message_data( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1426,8 +1482,8 @@ async def pin_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1492,8 +1548,8 @@ async def get_channel_thread_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1577,8 +1633,8 @@ async def update_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + and not channel_has_access( + user.id, channel, permission="read", db=db ) ): raise HTTPException( @@ -1644,10 +1700,10 @@ async def add_reaction_to_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1723,10 +1779,10 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1818,10 +1874,10 @@ async def delete_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( + and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 8b17dc406..4a12db5cd 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -38,6 +38,7 @@ from open_webui.models.files import ( from open_webui.models.chats import Chats from open_webui.models.knowledge import Knowledges from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.routers.retrieval import ProcessFileForm, process_file @@ -47,7 +48,6 @@ from open_webui.storage.provider import Storage from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.misc import strict_match_mime_type from pydantic import BaseModel @@ -82,8 +82,13 @@ def has_access_to_file( group.id for group in Groups.get_groups_by_member_id(user.id, db=db) } for knowledge_base in knowledge_bases: - if knowledge_base.user_id == user.id or has_access( - user.id, access_type, knowledge_base.access_control, user_group_ids, db=db + if knowledge_base.user_id == user.id or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission=access_type, + user_group_ids=user_group_ids, + db=db, ): return True diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index cc0cb8f5a..a46e05473 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -7,6 +7,7 @@ from open_webui.models.users import Users, UserInfoResponse from open_webui.models.groups import ( Groups, GroupForm, + GroupInfoResponse, GroupUpdateForm, GroupResponse, UserIdsForm, @@ -104,6 +105,23 @@ async def get_group_by_id( ) +@router.get("/id/{id}/info", response_model=Optional[GroupInfoResponse]) +async def get_group_info_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) + if group: + return GroupInfoResponse( + **group.model_dump(), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # ExportGroupById ############################ diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 96136d689..8350fda03 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -29,7 +29,8 @@ from open_webui.storage.provider import Storage from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_verified_user, get_admin_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL @@ -133,8 +134,12 @@ async def get_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -180,8 +185,12 @@ async def search_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -243,14 +252,14 @@ async def create_new_knowledge( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, ) ): - form_data.access_control = {} + form_data.access_grants = [] knowledge = Knowledges.insert_new_knowledge(user.id, form_data) @@ -387,7 +396,13 @@ async def get_knowledge_by_id( if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): return KnowledgeFilesResponse( @@ -395,7 +410,13 @@ async def get_knowledge_by_id( write_access=( user.id == knowledge.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access(user.id, "write", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) ), ) else: @@ -435,7 +456,12 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + ) and user.role != "admin" ): raise HTTPException( @@ -446,14 +472,14 @@ async def update_knowledge_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, ) ): - form_data.access_control = {} + form_data.access_grants = [] knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: @@ -502,7 +528,13 @@ async def get_knowledge_files_by_id( if not ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -555,7 +587,13 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -624,7 +662,13 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): @@ -693,7 +737,13 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -770,7 +820,13 @@ async def delete_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -802,7 +858,7 @@ async def delete_knowledge_by_id( base_model_id=model.base_model_id, meta=model.meta, params=model.params, - access_control=model.access_control, + access_grants=model.access_grants, is_active=model.is_active, ) Models.update_model_by_id(model.id, model_form, db=db) @@ -839,7 +895,13 @@ async def reset_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -882,7 +944,13 @@ async def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 85f0fb4f6..b2bccc195 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -15,6 +15,7 @@ from open_webui.models.models import ( ModelAccessResponse, Models, ) +from open_webui.models.access_grants import AccessGrants from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES @@ -30,7 +31,7 @@ from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -98,7 +99,13 @@ async def get_models( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) for model in result.items @@ -315,14 +322,26 @@ async def get_model_by_id( if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id - or has_access(user.id, "read", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="read", + db=db, + ) ): return ModelAccessResponse( **model.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) else: @@ -393,7 +412,13 @@ async def toggle_model_by_id( if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): model = Models.toggle_model_by_id(id, db=db) @@ -436,7 +461,13 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -471,7 +502,13 @@ async def delete_model_by_id( if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 56730e2b6..321b06fcd 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -27,7 +27,8 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -200,8 +201,12 @@ async def get_note_by_id( if user.role != "admin" and ( user.id != note.user_id and ( - not has_access( - user.id, type="read", access_control=note.access_control, db=db + not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + db=db, ) ) ): @@ -212,13 +217,14 @@ async def get_note_by_id( write_access = ( user.role == "admin" or (user.id == note.user_id) - or has_access( - user.id, - type="write", - access_control=note.access_control, - strict=False, + or AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", db=db, ) + or has_public_read_access_grant(note.access_grants) ) return NoteResponse(**note.model_dump(), write_access=write_access) @@ -253,8 +259,12 @@ async def update_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( @@ -264,7 +274,7 @@ async def update_note_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_notes", @@ -272,7 +282,7 @@ async def update_note_by_id( db=db, ) ): - form_data.access_control = {} + form_data.access_grants = [] try: note = Notes.update_note_by_id(id, form_data, db=db) @@ -318,8 +328,12 @@ async def delete_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index cd14f4526..4f43c41d3 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -45,6 +45,7 @@ from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.utils.misc import ( calculate_sha256, ) @@ -54,9 +55,6 @@ from open_webui.utils.payload import ( apply_system_prompt_to_body, ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access - - from open_webui.config import ( UPLOAD_DIR, ) @@ -431,8 +429,12 @@ async def get_filtered_models(models, user, db=None): for model in models.get("models", []): model_info = Models.get_model_by_id(model["model"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ): filtered_models.append(model) return filtered_models @@ -1293,7 +1295,7 @@ async def generate_chat_completion( raise HTTPException(status_code=503, detail="Ollama API is disabled") # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. if BYPASS_MODEL_ACCESS_CONTROL: @@ -1340,10 +1342,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1417,7 +1420,7 @@ async def generate_openai_completion( user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. metadata = form_data.pop("metadata", None) @@ -1452,10 +1455,11 @@ async def generate_openai_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1502,7 +1506,7 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. metadata = form_data.pop("metadata", None) @@ -1541,10 +1545,11 @@ async def generate_openai_chat_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1642,10 +1647,11 @@ async def get_openai_models( for model in models: model_info = Models.get_model_by_id(model["id"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, - type="read", - access_control=model_info.access_control, + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", db=db, ): filtered_models.append(model) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 5cb0deb69..d8ab50221 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import Session from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.config import ( CACHE_DIR, ) @@ -51,7 +52,6 @@ from open_webui.utils.misc import ( ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.headers import include_user_info_headers @@ -463,8 +463,12 @@ async def get_filtered_models(models, user, db=None): for model in models.get("data", []): model_info = Models.get_model_by_id(model["id"], db=db) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ): filtered_models.append(model) return filtered_models @@ -907,7 +911,7 @@ async def generate_chat_completion( bypass_system_prompt: bool = False, ): # NOTE: We intentionally do NOT use Depends(get_session) here. - # Database operations (get_model_by_id, has_access) manage their own short-lived sessions. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. # This prevents holding a connection during the entire LLM call (30-60+ seconds), # which would exhaust the connection pool under concurrent load. if BYPASS_MODEL_ACCESS_CONTROL: @@ -945,10 +949,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index fc24ccaf4..77c2e84fc 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -9,6 +9,7 @@ from open_webui.models.prompts import ( PromptModel, Prompts, ) +from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups from open_webui.models.prompt_history import ( PromptHistories, @@ -17,7 +18,7 @@ from open_webui.models.prompt_history import ( ) from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -115,7 +116,13 @@ async def get_prompt_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) for prompt in result.items @@ -186,14 +193,26 @@ async def get_prompt_by_command( if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) @@ -218,14 +237,26 @@ async def get_prompt_by_id( if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) @@ -258,7 +289,13 @@ async def update_prompt_by_id( # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -311,7 +348,13 @@ async def update_prompt_metadata( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -356,7 +399,13 @@ async def set_prompt_version( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -395,7 +444,13 @@ async def delete_prompt_by_id( if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -434,7 +489,13 @@ async def get_prompt_history( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -469,7 +530,13 @@ async def get_prompt_history_entry( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -508,7 +575,13 @@ async def delete_prompt_history_entry( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -553,7 +626,13 @@ async def get_prompt_diff( if not ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 7f9b23c7c..015bde232 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -21,6 +21,7 @@ from open_webui.models.tools import ( ToolAccessResponse, Tools, ) +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import ( load_tool_module_by_id, replace_imports, @@ -156,7 +157,24 @@ async def get_tools( tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control, user_group_ids, db=db) + or ( + has_access( + user.id, + "read", + getattr(tool, "access_control", None), + user_group_ids, + db=db, + ) + if str(tool.id).startswith("server:") + else AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + ) ] return tools @@ -181,7 +199,13 @@ async def get_tool_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tool.user_id - or has_access(user.id, "write", tool.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="write", + db=db, + ) ), ) for tool in tools @@ -382,14 +406,26 @@ async def get_tools_by_id( if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="read", + db=db, + ) ): return ToolAccessResponse( **tools.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tools.user_id - or has_access(user.id, "write", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) ), ) else: @@ -427,7 +463,13 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -489,7 +531,13 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -588,7 +636,13 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 20b69bcdf..5d8a1057d 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -19,7 +19,7 @@ from open_webui.models.users import ( UserModel, UserGroupIdsModel, UserGroupIdsListResponse, - UserInfoListResponse, + UserInfoResponse, UserInfoListResponse, UserRoleUpdateForm, UserStatus, @@ -446,7 +446,7 @@ class UserActiveResponse(UserStatus): @router.get("/{user_id}", response_model=UserActiveResponse) async def get_user_by_id( - user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) ): # Check if user_id is a shared chat # If it is, get the user_id from the chat @@ -478,6 +478,20 @@ async def get_user_by_id( ) +@router.get("/{user_id}/info", response_model=UserInfoResponse) +async def get_user_info_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user_id, db=db) + if user: + return user + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + @router.get("/{user_id}/oauth/sessions") async def get_user_oauth_sessions_by_id( user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 67e04e69c..e987f9c29 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -42,7 +42,7 @@ from open_webui.utils.auth import decode_token from open_webui.socket.utils import RedisDict, RedisLock, YdocManager from open_webui.tasks import create_task, stop_item_tasks from open_webui.utils.redis import get_redis_connection -from open_webui.utils.access_control import has_access, get_users_with_access +from open_webui.models.access_grants import AccessGrants from open_webui.env import ( @@ -405,7 +405,12 @@ async def join_note(sid, data): if ( user.role != "admin" and user.id != note.user_id - and not has_access(user.id, type="read", access_control=note.access_control) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): log.error(f"User {user.id} does not have access to note {data['note_id']}") return @@ -467,8 +472,11 @@ async def ydoc_document_join(sid, data): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error( @@ -537,8 +545,11 @@ async def document_save_handler(document_id, data, user): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error(f"User {user.get('id')} does not have access to note {note_id}") diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index 6aca621c9..9dd3ea1bc 100644 --- a/backend/open_webui/tools/builtin.py +++ b/backend/open_webui/tools/builtin.py @@ -743,10 +743,14 @@ async def view_note( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "read", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Access denied"}) @@ -797,7 +801,7 @@ async def write_note( form = NoteForm( title=title, data={"content": {"md": content}}, - access_control={}, # Private by default - only owner can access + access_grants=[], # Private by default - only owner can access ) new_note = Notes.insert_new_note(user_id, form) @@ -852,10 +856,14 @@ async def replace_note_content( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "write", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="write", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Write access denied"}) @@ -1532,7 +1540,7 @@ async def view_knowledge_file( try: from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1551,8 +1559,12 @@ async def view_knowledge_file( if ( user_role == "admin" or knowledge_base.user_id == user_id - or has_access( - user_id, "read", knowledge_base.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): has_knowledge_access = True @@ -1631,7 +1643,7 @@ async def query_knowledge_files( from open_webui.models.files import Files from open_webui.models.notes import Notes from open_webui.retrieval.utils import query_collection - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1656,8 +1668,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(item_id) @@ -1674,7 +1690,12 @@ async def query_knowledge_files( if note and ( user_role == "admin" or note.user_id == user_id - or has_access(user_id, "read", note.access_control) + or AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): content = note.data.get("content", {}).get("md", "") note_results.append( @@ -1693,8 +1714,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(knowledge_id) diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index b3a332ade..4224605f1 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -13,6 +13,7 @@ from open_webui.functions import get_function_models from open_webui.models.functions import Functions from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups @@ -354,8 +355,12 @@ def check_model_access(user, model, db=None): raise Exception("Model not found") elif not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ) ): raise Exception("Model not found") @@ -395,11 +400,13 @@ def get_filtered_models(models, user, db=None): if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", user_group_ids=user_group_ids, + db=db, ) ): filtered_models.append(model) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index ec281b43a..5bb523f83 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -38,6 +38,7 @@ from open_webui.utils.misc import is_string_allowed from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import load_tool_module_by_id from open_webui.utils.access_control import has_access from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL @@ -168,7 +169,13 @@ async def get_tools( if ( not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) and tool.user_id != user.id - and not has_access(user.id, "read", tool.access_control, user_group_ids) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + ) ): log.warning(f"Access denied to tool {tool_id} for user {user.id}") continue diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 225d8cd7c..5715c64e8 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -3,10 +3,11 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; type ChannelForm = { type?: string; name: string; - is_private?: boolean; + is_private?: boolean | null; data?: object; meta?: object; - access_control?: object; + access_grants?: object[]; + group_ids?: string[]; user_ids?: string[]; }; diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts index a74c61b83..6089a6023 100644 --- a/src/lib/apis/groups/index.ts +++ b/src/lib/apis/groups/index.ts @@ -99,6 +99,38 @@ export const getGroupById = async (token: string, id: string) => { return res; }; +export const getGroupInfoById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/info`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateGroupById = async (token: string, id: string, group: object) => { let error = null; diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index dc9dd8b88..4c7c90484 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -4,7 +4,7 @@ export const createNewKnowledge = async ( token: string, name: string, description: string, - accessControl: null | object + accessGrants: object[] ) => { let error = null; @@ -18,7 +18,7 @@ export const createNewKnowledge = async ( body: JSON.stringify({ name: name, description: description, - access_control: accessControl + access_grants: accessGrants }) }) .then(async (res) => { @@ -248,7 +248,7 @@ type KnowledgeUpdateForm = { name?: string; description?: string; data?: object; - access_control?: null | object; + access_grants?: object[]; }; export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -265,7 +265,7 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, data: form?.data ? form.data : undefined, - access_control: form.access_control + access_grants: form.access_grants }) }) .then(async (res) => { diff --git a/src/lib/apis/notes/index.ts b/src/lib/apis/notes/index.ts index 55f9427e0..341ced57e 100644 --- a/src/lib/apis/notes/index.ts +++ b/src/lib/apis/notes/index.ts @@ -5,7 +5,7 @@ type NoteItem = { title: string; data: object; meta?: null | object; - access_control?: null | object; + access_grants?: object[]; }; export const createNewNote = async (token: string, note: NoteItem) => { diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index e9cd6e848..c227c9f71 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -7,7 +7,7 @@ type PromptItem = { content: string; data?: object | null; meta?: object | null; - access_control?: null | object; + access_grants?: object[]; version_id?: string | null; // Active version commit_message?: string | null; // For history tracking is_production?: boolean; // Whether to set new version as production @@ -23,7 +23,7 @@ type PromptHistoryItem = { command: string; data: object; meta: object; - access_control: object | null; + access_grants: object[]; }; user_id: string; commit_message: string | null; @@ -42,7 +42,7 @@ type PromptDiff = { to_snapshot: object; content_diff: string[]; name_changed: boolean; - access_control_changed: boolean; + access_grants_changed: boolean; }; export const createNewPrompt = async (token: string, prompt: PromptItem) => { @@ -611,4 +611,3 @@ export const getPromptDiff = async ( return res; }; - diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index d6da54bbf..cd3b40adc 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -327,6 +327,33 @@ export const getUserById = async (token: string, userId: string) => { return res; }; +export const getUserInfoById = async (token: string, userId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/${userId}/info`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.error(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateUserStatus = async (token: string, formData: object) => { let error = null; diff --git a/src/lib/components/channel/ChannelInfoModal.svelte b/src/lib/components/channel/ChannelInfoModal.svelte index 44094f780..cd1ee3524 100644 --- a/src/lib/components/channel/ChannelInfoModal.svelte +++ b/src/lib/components/channel/ChannelInfoModal.svelte @@ -15,13 +15,32 @@ import AddMembersModal from './ChannelInfoModal/AddMembersModal.svelte'; export let show = false; - export let channel = null; + export let channel: any = null; export let onUpdate = () => {}; let showAddMembersModal = false; const submitHandler = async () => {}; + const hasPublicReadGrant = (grants: any) => + Array.isArray(grants) && + grants.some( + (grant) => + grant?.principal_type === 'user' && + grant?.principal_id === '*' && + grant?.permission === 'read' + ); + + const isPublicChannel = (channel: any): boolean => { + if (channel?.type === 'group') { + if (typeof channel?.is_private === 'boolean') { + return !channel.is_private; + } + return hasPublicReadGrant(channel?.access_grants); + } + return hasPublicReadGrant(channel?.access_grants); + }; + const removeMemberHandler = async (userId) => { const res = await removeMembersById(localStorage.token, channel.id, { user_ids: [userId] @@ -62,7 +81,7 @@ {:else}