Initial commit, user-company segregation

This commit is contained in:
Phil Szalay 2025-02-10 14:16:23 +01:00
parent e9d6ada25c
commit 07a6fbbb5f
32 changed files with 316 additions and 53 deletions

View File

@ -0,0 +1,138 @@
import json
from pydantic import BaseModel, ConfigDict
from typing import Optional
from sqlalchemy.orm import relationship
from sqlalchemy import String, Column, Text
from open_webui.internal.db import get_db, Base
####################
# Company DB Schema
####################
class Company(Base):
__tablename__ = "company"
id = Column(String, primary_key=True, unique=True)
name = Column(String, nullable=False)
profile_image_url = Column(Text, nullable=True)
default_model = Column(String, default="GPT 4o")
allowed_models = Column(Text, nullable=True)
users = relationship("User", back_populates="company", cascade="all, delete-orphan")
class CompanyModel(BaseModel):
id: str
name: str
profile_image_url: Optional[str] = None
default_model: Optional[str] = "GPT 4o"
allowed_models: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class CompanyModelForm(BaseModel):
id: str
model_id: str
class CompanyForm(BaseModel):
company: dict
class CompanyResponse(BaseModel):
id: str
name: str
profile_image_url: str
default_model: Optional[str] = "GPT 4o"
allowed_models: Optional[str]
class CompanyTable:
def get_company_by_id(self, company_id: str):
try:
with get_db() as db:
company = db.query(Company).filter_by(id=company_id).first()
return CompanyModel.model_validate(company)
except Exception as e:
print(f"Error getting company: {e}")
return None
def update_company_by_id(self, id: str, updated: dict) -> Optional[CompanyModel]:
try:
with get_db() as db:
db.query(Company).filter_by(id=id).update(updated)
db.commit()
company = db.query(Company).filter_by(id=id).first()
return CompanyModel.model_validate(company)
except Exception as e:
print(f"Error updating company", e)
return None
def add_model(self, company_id: str, model_id: str) -> bool:
try:
with get_db() as db:
# Fetch the company by its ID
company = db.query(Company).filter_by(id=company_id).first()
print("Company: ", company.allowed_models)
# If company doesn't exist, return False
if not company:
return None
company.allowed_models = '[]' if company.allowed_models is None else company.allowed_models
# Load current members from JSON
current_models = json.loads(company.allowed_models)
# If model_id is not already in the list, add it
if model_id not in current_models:
current_models.append(model_id)
payload = {"allowed_models": json.dumps(current_models)}
db.query(Company).filter_by(id=company_id).update(payload)
db.commit()
return True
else:
# Model already exists in the company
return False
except Exception as e:
# Handle exceptions if any
print("ERRRO::: ", e)
return False
def remove_model(self, company_id: str, model_id: str) -> bool:
try:
with get_db() as db:
# Fetch the company by its ID
company = db.query(Company).filter_by(id=company_id).first()
# If company doesn't exist, return False
if not company:
return None
# Load current members from JSON
current_models = json.loads(company.allowed_models)
# If model_id is in the list, remove it
if model_id in current_models:
current_models.remove(model_id)
payload = {"allowed_models": json.dumps(current_models)}
db.query(Company).filter_by(id=company_id).update(payload)
db.commit()
return True
else:
# Member not found in the company
return False
except Exception as e:
# Handle exceptions if any
return False
Companies = CompanyTable()

View File

@ -2,10 +2,13 @@ import logging
import time
from typing import Optional
from .companies import Companies
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users, UserResponse
from beyond_the_loop.models.users import Users
from beyond_the_loop.models.companies import CompanyResponse
from pydantic import BaseModel, ConfigDict
@ -78,6 +81,8 @@ class Model(Base):
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
company_id = Column(Text, nullable=False)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
@ -116,6 +121,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
company_id: str
model_config = ConfigDict(from_attributes=True)
@ -124,8 +131,8 @@ class ModelModel(BaseModel):
####################
class ModelUserResponse(ModelModel):
user: Optional[UserResponse] = None
class ModelCompanyResponse(ModelModel):
company: Optional[CompanyResponse] = None
class ModelResponse(ModelModel):
@ -144,12 +151,13 @@ class ModelForm(BaseModel):
class ModelsTable:
def insert_new_model(
self, form_data: ModelForm, user_id: str
self, form_data: ModelForm, user_id: str, company_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"company_id": company_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@ -173,16 +181,16 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]:
def get_models(self) -> list[ModelCompanyResponse]:
with get_db() as db:
models = []
for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id)
company = Companies.get_company_by_id(model.company_id)
models.append(
ModelUserResponse.model_validate(
ModelCompanyResponse.model_validate(
{
**ModelModel.model_validate(model).model_dump(),
"user": user.model_dump() if user else None,
"company": company.model_dump() if company else None,
}
)
)
@ -195,15 +203,15 @@ class ModelsTable:
for model in db.query(Model).filter(Model.base_model_id == None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
def get_models_by_company_id(
self, company_id: str, permission: str = "write"
) -> list[ModelCompanyResponse]:
models = self.get_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
if model.company_id == company_id
or has_access(company_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:

View File

@ -2,14 +2,18 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from functools import partial
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, ForeignKey
from sqlalchemy.orm import relationship
# TODO: I think this can be removed after we use companies somewhere else in the code
from .companies import Company
####################
# User DB Schema
@ -35,9 +39,12 @@ class User(Base):
oauth_sub = Column(Text, unique=True)
company_id = Column(String, ForeignKey("company.id", ondelete="CASCADE"), nullable=False)
company = relationship("Company", back_populates="users")
class UserSettings(BaseModel):
ui: Optional[dict] = {}
ui: Optional[dict] = partial(dict)
model_config = ConfigDict(extra="allow")
pass
@ -61,6 +68,7 @@ class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
company_id: str
####################
# Forms
@ -100,9 +108,10 @@ class UsersTable:
id: str,
name: str,
email: str,
company_id: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
oauth_sub: Optional[str] = None
) -> Optional[UserModel]:
with get_db() as db:
user = UserModel(
@ -116,6 +125,7 @@ class UsersTable:
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
"company_id": company_id
}
)
result = User(**user.model_dump())

View File

@ -1 +0,0 @@
docker dir for backend files (db, documents, etc.)

View File

@ -25,7 +25,7 @@ from open_webui.socket.main import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools

View File

@ -84,8 +84,8 @@ from open_webui.routers.retrieval import (
from open_webui.internal.db import Session
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users
from beyond_the_loop.models.models import Models
from beyond_the_loop.models.users import UserModel, Users
from open_webui.config import (
# Ollama

View File

@ -0,0 +1,36 @@
"""Add company table
Revision ID: 65ce764b8f7e
Revises: 3781e22d8b01
Create Date: 2025-02-10 12:34:00.054493
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = '65ce764b8f7e'
down_revision: Union[str, None] = '3781e22d8b01'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
'company',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column('default_model', sa.String(), server_default='GPT 4o'),
sa.Column('allowed_models', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
)
def downgrade() -> None:
op.drop_table('company')

View File

@ -0,0 +1,60 @@
"""Add company_id to user table
Revision ID: 9ca43b058511
Revises: 65ce764b8f7e
Create Date: 2025-02-10 13:11:02.691437
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision: str = '9ca43b058511'
down_revision: Union[str, None] = '65ce764b8f7e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def column_exists(table, column):
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
columns = inspector.get_columns(table)
return any(c["name"] == column for c in columns)
def upgrade() -> None:
# Add company_id column if it doesn't exist
if not column_exists("user", "company_id"):
op.add_column('user', sa.Column('company_id', sa.String(), nullable=True))
# Make columns non-nullable
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.alter_column('company_id', nullable=False)
# Add foreign key constraint if it doesn't exist
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
foreign_keys = inspector.get_foreign_keys('user')
if not any(fk['referred_table'] == 'company' for fk in foreign_keys):
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.create_foreign_key(
'fk_user_company_id', 'company',
['company_id'], ['id'], ondelete='CASCADE'
)
def downgrade() -> None:
with op.batch_alter_table('user', schema=None) as batch_op:
# Remove foreign key constraint if it exists
foreign_keys = Inspector.from_engine(op.get_bind()).get_foreign_keys('user')
if any(fk['name'] == 'fk_user_company_id' for fk in foreign_keys):
batch_op.drop_constraint('fk_user_company_id', type_='foreignkey')
# Drop columns if they exist
if column_exists("user", "company_id"):
batch_op.drop_column('company_id')

View File

@ -3,7 +3,7 @@ import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, Users
from beyond_the_loop.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
@ -91,7 +91,8 @@ class SignupForm(BaseModel):
class AddUserForm(SignupForm):
role: Optional[str] = "pending"
role: Optional[str] = "pending",
company_id: str
class AuthsTable:
@ -100,6 +101,7 @@ class AuthsTable:
email: str,
password: str,
name: str,
company_id: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
@ -116,7 +118,7 @@ class AuthsTable:
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
id, name, email, company_id, profile_image_url, role, oauth_sub
)
db.commit()

View File

@ -3,7 +3,7 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text

View File

@ -8,7 +8,7 @@ from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from beyond_the_loop.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import time
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import Users, UserResponse
from beyond_the_loop.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -3,7 +3,7 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from beyond_the_loop.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -15,7 +15,7 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.models.users import UserModel
from beyond_the_loop.models.users import UserModel
from open_webui.env import (
SRC_LOG_LEVELS,

View File

@ -18,7 +18,7 @@ from open_webui.models.auths import (
UpdateProfileForm,
UserResponse,
)
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
@ -258,7 +258,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
)
user = Auths.insert_new_auth(
email=mail, password=str(uuid.uuid4()), name=cn, role=role
email=mail, password=str(uuid.uuid4()), name=cn, company_id="NO_COMPANY", role=role
)
if not user:
@ -451,6 +451,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
form_data.email.lower(),
hashed,
form_data.name,
"NO_COMPANY",
form_data.profile_image_url,
role,
)
@ -564,6 +565,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
form_data.email.lower(),
hashed,
form_data.name,
form_data.company_id,
form_data.profile_image_url,
form_data.role,
)

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel
from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse
from beyond_the_loop.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
from open_webui.models.messages import (

View File

@ -2,7 +2,7 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.models.users import Users, UserModel
from beyond_the_loop.models.users import Users, UserModel
from open_webui.models.feedbacks import (
FeedbackModel,
FeedbackResponse,

View File

@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.models.groups import (
Groups,
GroupForm,

View File

@ -25,7 +25,7 @@ from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.models import Models, ModelForm
from beyond_the_loop.models.models import Models, ModelForm
log = logging.getLogger(__name__)

View File

@ -1,10 +1,10 @@
from typing import Optional
from open_webui.models.models import (
from beyond_the_loop.models.models import (
ModelForm,
ModelModel,
ModelResponse,
ModelUserResponse,
ModelCompanyResponse,
Models,
)
from open_webui.constants import ERROR_MESSAGES
@ -23,7 +23,7 @@ router = APIRouter()
###########################
@router.get("/", response_model=list[ModelUserResponse])
@router.get("/", response_model=list[ModelCompanyResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if user.role == "admin":
return Models.get_models()

View File

@ -32,7 +32,7 @@ from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.utils.misc import (
calculate_sha256,
)

View File

@ -16,7 +16,7 @@ from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.config import (
CACHE_DIR,
)

View File

@ -3,7 +3,7 @@ from typing import Optional
from open_webui.models.auths import Auths
from open_webui.models.chats import Chats
from open_webui.models.users import (
from beyond_the_loop.models.users import (
UserModel,
UserRoleUpdateForm,
Users,

View File

@ -4,7 +4,7 @@ import logging
import sys
import time
from open_webui.models.users import Users, UserNameResponse
from beyond_the_loop.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats

View File

@ -32,6 +32,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="user",
)
@ -53,6 +54,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="user",
)
@ -80,6 +82,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="user",
)
@ -142,6 +145,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password="password",
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="admin",
)
@ -159,6 +163,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password="password",
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="admin",
)
@ -174,6 +179,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password="password",
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="admin",
)
@ -190,6 +196,7 @@ class TestAuths(AbstractPostgresTest):
email="john.doe@openwebui.com",
password="password",
name="John Doe",
company_id="1",
profile_image_url="/user.png",
role="admin",
)

View File

@ -1,5 +1,5 @@
from typing import Optional, Union, List, Dict, Any
from open_webui.models.users import Users, UserModel
from beyond_the_loop.models.users import Users, UserModel
from open_webui.models.groups import Groups

View File

@ -5,7 +5,7 @@ import jwt
from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY

View File

@ -12,7 +12,7 @@ from fastapi import Request
from starlette.responses import Response, StreamingResponse
from open_webui.models.users import UserModel
from beyond_the_loop.models.users import UserModel
from open_webui.socket.main import (
get_event_call,
@ -34,7 +34,7 @@ from open_webui.routers.pipelines import (
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id

View File

@ -25,7 +25,7 @@ from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
@ -44,9 +44,9 @@ from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.utils.webhook import post_webhook
from open_webui.models.users import UserModel
from beyond_the_loop.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.retrieval.utils import get_sources_from_files

View File

@ -10,7 +10,7 @@ from open_webui.functions import get_function_models
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from beyond_the_loop.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id

View File

@ -13,7 +13,7 @@ from fastapi import (
from starlette.responses import RedirectResponse
from open_webui.models.auths import Auths
from open_webui.models.users import Users
from beyond_the_loop.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
from open_webui.config import (
DEFAULT_USER_ROLE,
@ -292,6 +292,7 @@ class OAuthManager:
str(uuid.uuid4())
), # Random password, not used
name=name,
company_id="NO_COMPANY",
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,

View File

@ -11,7 +11,7 @@ from langchain_core.utils.function_calling import convert_to_openai_function
from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from beyond_the_loop.models.users import UserModel
from open_webui.utils.plugin import load_tools_module_by_id
log = logging.getLogger(__name__)