mirror of
https://github.com/open-webui/open-webui
synced 2025-04-07 22:25:05 +00:00
Initial commit, user-company segregation
This commit is contained in:
parent
e9d6ada25c
commit
07a6fbbb5f
138
backend/beyond_the_loop/models/companies.py
Normal file
138
backend/beyond_the_loop/models/companies.py
Normal file
@ -0,0 +1,138 @@
|
||||
import json
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import String, Column, Text
|
||||
|
||||
from open_webui.internal.db import get_db, Base
|
||||
|
||||
####################
|
||||
# Company DB Schema
|
||||
####################
|
||||
|
||||
class Company(Base):
|
||||
__tablename__ = "company"
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
name = Column(String, nullable=False)
|
||||
profile_image_url = Column(Text, nullable=True)
|
||||
default_model = Column(String, default="GPT 4o")
|
||||
allowed_models = Column(Text, nullable=True)
|
||||
users = relationship("User", back_populates="company", cascade="all, delete-orphan")
|
||||
|
||||
class CompanyModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
profile_image_url: Optional[str] = None
|
||||
default_model: Optional[str] = "GPT 4o"
|
||||
allowed_models: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
class CompanyModelForm(BaseModel):
|
||||
id: str
|
||||
model_id: str
|
||||
|
||||
class CompanyForm(BaseModel):
|
||||
company: dict
|
||||
|
||||
|
||||
class CompanyResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
profile_image_url: str
|
||||
default_model: Optional[str] = "GPT 4o"
|
||||
allowed_models: Optional[str]
|
||||
|
||||
|
||||
class CompanyTable:
|
||||
def get_company_by_id(self, company_id: str):
|
||||
try:
|
||||
with get_db() as db:
|
||||
company = db.query(Company).filter_by(id=company_id).first()
|
||||
return CompanyModel.model_validate(company)
|
||||
except Exception as e:
|
||||
print(f"Error getting company: {e}")
|
||||
return None
|
||||
|
||||
def update_company_by_id(self, id: str, updated: dict) -> Optional[CompanyModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Company).filter_by(id=id).update(updated)
|
||||
db.commit()
|
||||
|
||||
company = db.query(Company).filter_by(id=id).first()
|
||||
return CompanyModel.model_validate(company)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating company", e)
|
||||
return None
|
||||
|
||||
|
||||
def add_model(self, company_id: str, model_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Fetch the company by its ID
|
||||
company = db.query(Company).filter_by(id=company_id).first()
|
||||
print("Company: ", company.allowed_models)
|
||||
# If company doesn't exist, return False
|
||||
if not company:
|
||||
return None
|
||||
|
||||
company.allowed_models = '[]' if company.allowed_models is None else company.allowed_models
|
||||
# Load current members from JSON
|
||||
current_models = json.loads(company.allowed_models)
|
||||
|
||||
# If model_id is not already in the list, add it
|
||||
if model_id not in current_models:
|
||||
current_models.append(model_id)
|
||||
|
||||
payload = {"allowed_models": json.dumps(current_models)}
|
||||
db.query(Company).filter_by(id=company_id).update(payload)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
else:
|
||||
# Model already exists in the company
|
||||
return False
|
||||
except Exception as e:
|
||||
# Handle exceptions if any
|
||||
print("ERRRO::: ", e)
|
||||
return False
|
||||
|
||||
def remove_model(self, company_id: str, model_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Fetch the company by its ID
|
||||
company = db.query(Company).filter_by(id=company_id).first()
|
||||
|
||||
# If company doesn't exist, return False
|
||||
if not company:
|
||||
return None
|
||||
|
||||
# Load current members from JSON
|
||||
current_models = json.loads(company.allowed_models)
|
||||
|
||||
# If model_id is in the list, remove it
|
||||
if model_id in current_models:
|
||||
current_models.remove(model_id)
|
||||
|
||||
payload = {"allowed_models": json.dumps(current_models)}
|
||||
db.query(Company).filter_by(id=company_id).update(payload)
|
||||
db.commit()
|
||||
return True
|
||||
else:
|
||||
# Member not found in the company
|
||||
return False
|
||||
except Exception as e:
|
||||
# Handle exceptions if any
|
||||
return False
|
||||
|
||||
|
||||
|
||||
Companies = CompanyTable()
|
@ -2,10 +2,13 @@ import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from .companies import Companies
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from beyond_the_loop.models.users import Users
|
||||
|
||||
from beyond_the_loop.models.companies import CompanyResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -78,6 +81,8 @@ class Model(Base):
|
||||
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
||||
"""
|
||||
|
||||
company_id = Column(Text, nullable=False)
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
# - `None`: Public access, available to all users with the "user" role.
|
||||
@ -116,6 +121,8 @@ class ModelModel(BaseModel):
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
company_id: str
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@ -124,8 +131,8 @@ class ModelModel(BaseModel):
|
||||
####################
|
||||
|
||||
|
||||
class ModelUserResponse(ModelModel):
|
||||
user: Optional[UserResponse] = None
|
||||
class ModelCompanyResponse(ModelModel):
|
||||
company: Optional[CompanyResponse] = None
|
||||
|
||||
|
||||
class ModelResponse(ModelModel):
|
||||
@ -144,12 +151,13 @@ class ModelForm(BaseModel):
|
||||
|
||||
class ModelsTable:
|
||||
def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str
|
||||
self, form_data: ModelForm, user_id: str, company_id: str
|
||||
) -> Optional[ModelModel]:
|
||||
model = ModelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"company_id": company_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
@ -173,16 +181,16 @@ class ModelsTable:
|
||||
with get_db() as db:
|
||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
||||
|
||||
def get_models(self) -> list[ModelUserResponse]:
|
||||
def get_models(self) -> list[ModelCompanyResponse]:
|
||||
with get_db() as db:
|
||||
models = []
|
||||
for model in db.query(Model).filter(Model.base_model_id != None).all():
|
||||
user = Users.get_user_by_id(model.user_id)
|
||||
company = Companies.get_company_by_id(model.company_id)
|
||||
models.append(
|
||||
ModelUserResponse.model_validate(
|
||||
ModelCompanyResponse.model_validate(
|
||||
{
|
||||
**ModelModel.model_validate(model).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
"company": company.model_dump() if company else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
@ -195,15 +203,15 @@ class ModelsTable:
|
||||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ModelUserResponse]:
|
||||
def get_models_by_company_id(
|
||||
self, company_id: str, permission: str = "write"
|
||||
) -> list[ModelCompanyResponse]:
|
||||
models = self.get_models()
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
if model.user_id == user_id
|
||||
or has_access(user_id, permission, model.access_control)
|
||||
if model.company_id == company_id
|
||||
or has_access(company_id, permission, model.access_control)
|
||||
]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
@ -2,14 +2,18 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
|
||||
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
from functools import partial
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, Text, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# TODO: I think this can be removed after we use companies somewhere else in the code
|
||||
from .companies import Company
|
||||
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
@ -35,9 +39,12 @@ class User(Base):
|
||||
|
||||
oauth_sub = Column(Text, unique=True)
|
||||
|
||||
company_id = Column(String, ForeignKey("company.id", ondelete="CASCADE"), nullable=False)
|
||||
company = relationship("Company", back_populates="users")
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
ui: Optional[dict] = partial(dict)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
@ -61,6 +68,7 @@ class UserModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
company_id: str
|
||||
|
||||
####################
|
||||
# Forms
|
||||
@ -100,9 +108,10 @@ class UsersTable:
|
||||
id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
company_id: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
oauth_sub: Optional[str] = None
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
user = UserModel(
|
||||
@ -116,6 +125,7 @@ class UsersTable:
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"oauth_sub": oauth_sub,
|
||||
"company_id": company_id
|
||||
}
|
||||
)
|
||||
result = User(**user.model_dump())
|
@ -1 +0,0 @@
|
||||
docker dir for backend files (db, documents, etc.)
|
@ -25,7 +25,7 @@ from open_webui.socket.main import (
|
||||
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.tools import get_tools
|
||||
|
@ -84,8 +84,8 @@ from open_webui.routers.retrieval import (
|
||||
from open_webui.internal.db import Session
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.models.users import UserModel, Users
|
||||
from beyond_the_loop.models.models import Models
|
||||
from beyond_the_loop.models.users import UserModel, Users
|
||||
|
||||
from open_webui.config import (
|
||||
# Ollama
|
||||
|
@ -0,0 +1,36 @@
|
||||
"""Add company table
|
||||
|
||||
Revision ID: 65ce764b8f7e
|
||||
Revises: 3781e22d8b01
|
||||
Create Date: 2025-02-10 12:34:00.054493
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '65ce764b8f7e'
|
||||
down_revision: Union[str, None] = '3781e22d8b01'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'company',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||
sa.Column('default_model', sa.String(), server_default='GPT 4o'),
|
||||
sa.Column('allowed_models', sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('company')
|
@ -0,0 +1,60 @@
|
||||
"""Add company_id to user table
|
||||
|
||||
Revision ID: 9ca43b058511
|
||||
Revises: 65ce764b8f7e
|
||||
Create Date: 2025-02-10 13:11:02.691437
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9ca43b058511'
|
||||
down_revision: Union[str, None] = '65ce764b8f7e'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def column_exists(table, column):
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
columns = inspector.get_columns(table)
|
||||
return any(c["name"] == column for c in columns)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add company_id column if it doesn't exist
|
||||
if not column_exists("user", "company_id"):
|
||||
op.add_column('user', sa.Column('company_id', sa.String(), nullable=True))
|
||||
|
||||
# Make columns non-nullable
|
||||
with op.batch_alter_table('user', schema=None) as batch_op:
|
||||
batch_op.alter_column('company_id', nullable=False)
|
||||
|
||||
# Add foreign key constraint if it doesn't exist
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
foreign_keys = inspector.get_foreign_keys('user')
|
||||
if not any(fk['referred_table'] == 'company' for fk in foreign_keys):
|
||||
with op.batch_alter_table('user', schema=None) as batch_op:
|
||||
batch_op.create_foreign_key(
|
||||
'fk_user_company_id', 'company',
|
||||
['company_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table('user', schema=None) as batch_op:
|
||||
# Remove foreign key constraint if it exists
|
||||
foreign_keys = Inspector.from_engine(op.get_bind()).get_foreign_keys('user')
|
||||
if any(fk['name'] == 'fk_user_company_id' for fk in foreign_keys):
|
||||
batch_op.drop_constraint('fk_user_company_id', type_='foreignkey')
|
||||
|
||||
# Drop columns if they exist
|
||||
if column_exists("user", "company_id"):
|
||||
batch_op.drop_column('company_id')
|
@ -3,7 +3,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.users import UserModel, Users
|
||||
from beyond_the_loop.models.users import UserModel, Users
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Boolean, Column, String, Text
|
||||
@ -91,7 +91,8 @@ class SignupForm(BaseModel):
|
||||
|
||||
|
||||
class AddUserForm(SignupForm):
|
||||
role: Optional[str] = "pending"
|
||||
role: Optional[str] = "pending",
|
||||
company_id: str
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
@ -100,6 +101,7 @@ class AuthsTable:
|
||||
email: str,
|
||||
password: str,
|
||||
name: str,
|
||||
company_id: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
@ -116,7 +118,7 @@ class AuthsTable:
|
||||
db.add(result)
|
||||
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth_sub
|
||||
id, name, email, company_id, profile_image_url, role, oauth_sub
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
@ -3,7 +3,7 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||
|
@ -8,7 +8,7 @@ from open_webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.files import FileMetadataResponse
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from beyond_the_loop.models.users import Users, UserResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
@ -2,7 +2,7 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from beyond_the_loop.models.users import Users, UserResponse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
@ -3,7 +3,7 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from beyond_the_loop.models.users import Users, UserResponse
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
@ -15,7 +15,7 @@ from langchain_core.documents import Document
|
||||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
from open_webui.models.users import UserModel
|
||||
from beyond_the_loop.models.users import UserModel
|
||||
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
|
@ -18,7 +18,7 @@ from open_webui.models.auths import (
|
||||
UpdateProfileForm,
|
||||
UserResponse,
|
||||
)
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
@ -258,7 +258,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, role=role
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, company_id="NO_COMPANY", role=role
|
||||
)
|
||||
|
||||
if not user:
|
||||
@ -451,6 +451,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
form_data.email.lower(),
|
||||
hashed,
|
||||
form_data.name,
|
||||
"NO_COMPANY",
|
||||
form_data.profile_image_url,
|
||||
role,
|
||||
)
|
||||
@ -564,6 +565,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
form_data.email.lower(),
|
||||
hashed,
|
||||
form_data.name,
|
||||
form_data.company_id,
|
||||
form_data.profile_image_url,
|
||||
form_data.role,
|
||||
)
|
||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.socket.main import sio, get_user_ids_from_room
|
||||
from open_webui.models.users import Users, UserNameResponse
|
||||
from beyond_the_loop.models.users import Users, UserNameResponse
|
||||
|
||||
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
|
||||
from open_webui.models.messages import (
|
||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from beyond_the_loop.models.users import Users, UserModel
|
||||
from open_webui.models.feedbacks import (
|
||||
FeedbackModel,
|
||||
FeedbackResponse,
|
||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
from open_webui.models.groups import (
|
||||
Groups,
|
||||
GroupForm,
|
||||
|
@ -25,7 +25,7 @@ from open_webui.utils.access_control import has_access, has_permission
|
||||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.models.models import Models, ModelForm
|
||||
from beyond_the_loop.models.models import Models, ModelForm
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.models.models import (
|
||||
from beyond_the_loop.models.models import (
|
||||
ModelForm,
|
||||
ModelModel,
|
||||
ModelResponse,
|
||||
ModelUserResponse,
|
||||
ModelCompanyResponse,
|
||||
Models,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@ -23,7 +23,7 @@ router = APIRouter()
|
||||
###########################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ModelUserResponse])
|
||||
@router.get("/", response_model=list[ModelCompanyResponse])
|
||||
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
return Models.get_models()
|
||||
|
@ -32,7 +32,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
from open_webui.utils.misc import (
|
||||
calculate_sha256,
|
||||
)
|
||||
|
@ -16,7 +16,7 @@ from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
)
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.users import (
|
||||
from beyond_the_loop.models.users import (
|
||||
UserModel,
|
||||
UserRoleUpdateForm,
|
||||
Users,
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from open_webui.models.users import Users, UserNameResponse
|
||||
from beyond_the_loop.models.users import Users, UserNameResponse
|
||||
from open_webui.models.channels import Channels
|
||||
from open_webui.models.chats import Chats
|
||||
|
||||
|
@ -32,6 +32,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("old_password"),
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
)
|
||||
@ -53,6 +54,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("old_password"),
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
)
|
||||
@ -80,6 +82,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("password"),
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
)
|
||||
@ -142,6 +145,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
)
|
||||
@ -159,6 +163,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
)
|
||||
@ -174,6 +179,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
)
|
||||
@ -190,6 +196,7 @@ class TestAuths(AbstractPostgresTest):
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
company_id="1",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from beyond_the_loop.models.users import Users, UserModel
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ import jwt
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SECRET_KEY
|
||||
|
@ -12,7 +12,7 @@ from fastapi import Request
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from beyond_the_loop.models.users import UserModel
|
||||
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
@ -34,7 +34,7 @@ from open_webui.routers.pipelines import (
|
||||
)
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
@ -25,7 +25,7 @@ from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
@ -44,9 +44,9 @@ from open_webui.routers.images import image_generations, GenerateImageForm
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from beyond_the_loop.models.users import UserModel
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
|
||||
from open_webui.retrieval.utils import get_sources_from_files
|
||||
|
||||
|
@ -10,7 +10,7 @@ from open_webui.functions import get_function_models
|
||||
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from beyond_the_loop.models.models import Models
|
||||
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
@ -13,7 +13,7 @@ from fastapi import (
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from open_webui.models.auths import Auths
|
||||
from open_webui.models.users import Users
|
||||
from beyond_the_loop.models.users import Users
|
||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
|
||||
from open_webui.config import (
|
||||
DEFAULT_USER_ROLE,
|
||||
@ -292,6 +292,7 @@ class OAuthManager:
|
||||
str(uuid.uuid4())
|
||||
), # Random password, not used
|
||||
name=name,
|
||||
company_id="NO_COMPANY",
|
||||
profile_image_url=picture_url,
|
||||
role=role,
|
||||
oauth_sub=provider_sub,
|
||||
|
@ -11,7 +11,7 @@ from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
|
||||
from open_webui.models.tools import Tools
|
||||
from open_webui.models.users import UserModel
|
||||
from beyond_the_loop.models.users import UserModel
|
||||
from open_webui.utils.plugin import load_tools_module_by_id
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user