Add company model

This commit is contained in:
Phil Szalay 2025-02-13 17:48:02 +01:00
parent c49a7776e9
commit e95b753e71

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict
from typing import Optional
from sqlalchemy.orm import relationship
from sqlalchemy import String, Column, Text
from sqlalchemy import String, Column, Text, Integer
from open_webui.internal.db import get_db, Base
@ -22,6 +22,7 @@ class Company(Base):
profile_image_url = Column(Text, nullable=True)
default_model = Column(String, default="GPT 4o")
allowed_models = Column(Text, nullable=True)
users = relationship("User", back_populates="company", cascade="all, delete-orphan")
class CompanyModel(BaseModel):
@ -30,6 +31,7 @@ class CompanyModel(BaseModel):
profile_image_url: Optional[str] = None
default_model: Optional[str] = "GPT 4o"
allowed_models: Optional[str] = None
token_balance: Optional[int] = 0
model_config = ConfigDict(from_attributes=True)
@ -136,6 +138,38 @@ class CompanyTable:
# Handle exceptions if any
return False
def update_token_balance(self, company_id: str, tokens_used: int) -> bool:
"""Update company's token balance by subtracting tokens used"""
with get_db() as db:
company = db.query(Company).filter(Company.id == company_id).first()
if company and company.token_balance is not None:
company.token_balance -= tokens_used
db.commit()
return True
return False
def add_token_balance(self, company_id: str, tokens_to_add: int) -> bool:
"""Add tokens to company's balance"""
with get_db() as db:
company = db.query(Company).filter(Company.id == company_id).first()
if company:
if company.token_balance is None:
company.token_balance = tokens_to_add
else:
company.token_balance += tokens_to_add
db.commit()
return True
return False
def get_token_balance(self, company_id: str) -> Optional[int]:
"""Get company's current token balance"""
with get_db() as db:
company = db.query(Company).filter(Company.id == company_id).first()
return company.token_balance if company else None
def has_sufficient_tokens(self, company_id: str, required_tokens: int) -> bool:
"""Check if company has sufficient tokens for an operation"""
balance = self.get_token_balance(company_id)
return balance is None or balance >= required_tokens # None means unlimited
Companies = CompanyTable()