# Mirror of https://github.com/open-webui/open-webui
# (synced 2025-04-25 16:49:46 +00:00) - 237 lines, 8.3 KiB, Python
import json
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Boolean, Column, Integer, String, Text, func
from sqlalchemy.orm import relationship

from open_webui.internal.db import Base, get_db

# Constants

# NOTE(review): placeholder company identifier; presumably assigned to users
# that do not belong to any company — confirm against callers.
NO_COMPANY = "NO_COMPANY"

# NOTE(review): the name suggests a warning threshold at 80% of some credit
# amount; the base amount and units are not defined in this file — confirm.
EIGHTY_PERCENT_CREDIT_LIMIT = 4000

####################
# Company DB Schema
####################


class Company(Base):
    """SQLAlchemy ORM model for the ``company`` table."""

    __tablename__ = "company"

    id = Column(String, primary_key=True, unique=True)
    name = Column(String, nullable=False)
    profile_image_url = Column(Text, nullable=True)
    default_model = Column(String, nullable=True)
    # JSON-encoded list of model ids; CompanyTable.add_model / remove_model
    # read it with json.loads and write it back with json.dumps.
    allowed_models = Column(Text, nullable=True)
    # Nullable integer; new rows default to 0.
    credit_balance = Column(Integer, default=0)
    auto_recharge = Column(Boolean, default=False)
    # NOTE(review): stored as a plain string column — confirm the value is
    # tokenized/masked before it reaches this table.
    credit_card_number = Column(String, nullable=True)

    # One-to-many link to User, which is expected to declare a matching
    # ``company`` relationship. The "all, delete-orphan" cascade means that
    # deleting a Company through the ORM session also deletes its users.
    users = relationship("User", back_populates="company", cascade="all, delete-orphan")

class CompanyModel(BaseModel):
    """Pydantic view of a ``Company`` row, typically built with ``model_validate(orm_obj)``."""

    id: str
    name: str
    profile_image_url: Optional[str] = None
    # NOTE(review): unlike the DB column (no default), this schema defaults to
    # "GPT 4o"; the default only applies when the attribute/key is absent,
    # not when it is present with a NULL value.
    default_model: Optional[str] = "GPT 4o"
    # Raw JSON-encoded list of model ids, kept as a string.
    allowed_models: Optional[str] = None
    credit_balance: Optional[int] = 0
    auto_recharge: Optional[bool] = False
    credit_card_number: Optional[str] = None

    # Lets model_validate() read attributes from ORM objects instead of dict keys.
    model_config = ConfigDict(from_attributes=True)

####################
# Forms
####################


class CompanyModelForm(BaseModel):
    """Payload pairing a company ``id`` with a ``model_id``.

    NOTE(review): presumably the request body for adding/removing an allowed
    model — confirm against the router that consumes it.
    """

    id: str
    model_id: str

class CompanyForm(BaseModel):
    """Payload wrapping company fields in a single untyped ``company`` dict.

    NOTE(review): the dict's keys are not validated here; presumably they are
    Company column names (e.g. passed to CompanyTable.create_company) — confirm.
    """

    company: dict

class CompanyResponse(BaseModel):
    """Response schema exposing a subset of company fields.

    It deliberately omits ``credit_balance`` and ``credit_card_number``.
    """

    id: str
    name: str
    profile_image_url: Optional[str] = None
    default_model: Optional[str] = "GPT 4o"
    # No default: the field must be supplied, although None is accepted.
    allowed_models: Optional[str]
    auto_recharge: bool

class CompanyTable:
    """Data-access helpers for the ``company`` table.

    Every method opens its own session with ``get_db()``. The lookup/update
    helpers catch exceptions, print a diagnostic and return ``None``/``False``;
    the credit-balance helpers do not catch database errors.
    """

    @staticmethod
    def _load_allowed_models(raw: Optional[str]) -> list:
        """Decode the JSON ``allowed_models`` column into a list of model ids.

        A NULL/empty column or a JSON ``null`` means "no models". Any other
        non-list JSON value raises ValueError so callers never overwrite data
        they do not understand.
        """
        if not raw:
            return []
        models = json.loads(raw)
        if models is None:
            return []
        if not isinstance(models, list):
            raise ValueError(f"allowed_models is not a JSON list: {raw!r}")
        return models

    def get_company_by_id(self, company_id: str) -> Optional[CompanyModel]:
        """Return the company with ``company_id``, or None if it is missing or the lookup fails."""
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if company is None:
                    return None
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error getting company: {e}")
            return None

    def update_company_by_id(self, id: str, updated: dict) -> Optional[CompanyModel]:
        """Apply ``updated`` (column name -> new value) to company ``id``.

        ``id`` shadows the builtin but is kept because it is part of the public
        signature. Returns the updated company, or None if it does not exist
        or the update fails.
        """
        try:
            with get_db() as db:
                db.query(Company).filter_by(id=id).update(updated)
                db.commit()

                company = db.query(Company).filter_by(id=id).first()
                if company is None:
                    return None
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error updating company {id}: {e}")
            return None

    def update_auto_recharge(self, company_id: str, auto_recharge: bool) -> Optional[CompanyModel]:
        """Set the company's ``auto_recharge`` flag.

        Returns the updated company, or None if it does not exist or the
        update fails.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if company is None:
                    print(f"Company with ID {company_id} not found.")
                    return None

                company.auto_recharge = auto_recharge
                db.commit()
                db.refresh(company)
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error updating auto_recharge for company {company_id}: {e}")
            return None

    def get_auto_recharge(self, company_id: str) -> Optional[bool]:
        """Return the company's ``auto_recharge`` flag, or None if it is missing or the lookup fails."""
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if company is None:
                    print(f"Company with ID {company_id} not found.")
                    return None
                return company.auto_recharge
        except Exception as e:
            print(f"Error retrieving auto_recharge for company {company_id}: {e}")
            return None

    def add_model(self, company_id: str, model_id: str) -> bool:
        """Append ``model_id`` to the company's allowed-models list.

        Returns True if the model was added; False if it was already present,
        the company does not exist, or an error occurred.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                # Check existence before touching any attribute of the row.
                if company is None:
                    return False

                current_models = self._load_allowed_models(company.allowed_models)
                if model_id in current_models:
                    return False

                current_models.append(model_id)
                company.allowed_models = json.dumps(current_models)
                db.commit()
                return True
        except Exception as e:
            print(f"Error adding model {model_id} to company {company_id}: {e}")
            return False

    def remove_model(self, company_id: str, model_id: str) -> bool:
        """Remove ``model_id`` from the company's allowed-models list.

        Returns True if the model was removed; False if it was not present,
        the company does not exist, or an error occurred.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if company is None:
                    return False

                current_models = self._load_allowed_models(company.allowed_models)
                if model_id not in current_models:
                    return False

                current_models.remove(model_id)
                company.allowed_models = json.dumps(current_models)
                db.commit()
                return True
        except Exception as e:
            print(f"Error removing model {model_id} from company {company_id}: {e}")
            return False

    def update_credit_balance(self, company_id: str, credits_used: int) -> bool:
        """Update company's credit balance by subtracting credits used.

        The subtraction runs as a single SQL UPDATE so concurrent callers
        cannot overwrite each other's changes. Returns False when the company
        does not exist or its balance is NULL.
        """
        with get_db() as db:
            updated_rows = (
                db.query(Company)
                .filter(Company.id == company_id, Company.credit_balance.isnot(None))
                .update(
                    {Company.credit_balance: Company.credit_balance - credits_used},
                    synchronize_session=False,
                )
            )
            db.commit()
            return updated_rows > 0

    def add_credit_balance(self, company_id: str, credits_to_add: int) -> bool:
        """Add credits to company's balance.

        A NULL balance is treated as 0. The addition runs as a single SQL
        UPDATE so concurrent callers cannot overwrite each other's changes.
        Returns False when the company does not exist.
        """
        with get_db() as db:
            updated_rows = (
                db.query(Company)
                .filter(Company.id == company_id)
                .update(
                    {
                        Company.credit_balance: func.coalesce(Company.credit_balance, 0)
                        + credits_to_add
                    },
                    synchronize_session=False,
                )
            )
            db.commit()
            return updated_rows > 0

    def subtract_credit_balance(self, company_id: str, credits_to_subtract: int) -> bool:
        """Subtract credits from company's balance.

        Only succeeds when the current balance is non-NULL and at least
        ``credits_to_subtract``. The check and the subtraction happen in one
        SQL UPDATE, so two concurrent callers cannot both pass the check
        against the same stale balance. Returns False otherwise.
        """
        with get_db() as db:
            updated_rows = (
                db.query(Company)
                .filter(
                    Company.id == company_id,
                    # NULL >= n is not true in SQL, so NULL balances are excluded.
                    Company.credit_balance >= credits_to_subtract,
                )
                .update(
                    {Company.credit_balance: Company.credit_balance - credits_to_subtract},
                    synchronize_session=False,
                )
            )
            db.commit()
            return updated_rows > 0

    def get_credit_balance(self, company_id: str) -> Optional[int]:
        """Get company's current credit balance (None if the company does not exist)."""
        with get_db() as db:
            company = db.query(Company).filter(Company.id == company_id).first()
            return company.credit_balance if company else None

    def create_company(self, company_data: dict) -> Optional[CompanyModel]:
        """Create a new company from ``company_data`` (column name -> value).

        Returns the created company, or None if the insert fails.
        """
        try:
            with get_db() as db:
                company = Company(**company_data)
                db.add(company)
                db.commit()
                db.refresh(company)
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error creating company: {e}")
            return None

# Shared module-level instance; CompanyTable keeps no per-instance state.
Companies = CompanyTable()