mirror of
https://github.com/open-webui/open-webui
synced 2025-04-10 15:45:45 +00:00
Add company model
This commit is contained in:
parent
c49a7776e9
commit
e95b753e71
@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import String, Column, Text
|
||||
from sqlalchemy import String, Column, Text, Integer
|
||||
|
||||
from open_webui.internal.db import get_db, Base
|
||||
|
||||
@ -22,6 +22,7 @@ class Company(Base):
|
||||
profile_image_url = Column(Text, nullable=True)
|
||||
default_model = Column(String, default="GPT 4o")
|
||||
allowed_models = Column(Text, nullable=True)
|
||||
|
||||
users = relationship("User", back_populates="company", cascade="all, delete-orphan")
|
||||
|
||||
class CompanyModel(BaseModel):
|
||||
@ -30,6 +31,7 @@ class CompanyModel(BaseModel):
|
||||
profile_image_url: Optional[str] = None
|
||||
default_model: Optional[str] = "GPT 4o"
|
||||
allowed_models: Optional[str] = None
|
||||
token_balance: Optional[int] = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@ -136,6 +138,38 @@ class CompanyTable:
|
||||
# Handle exceptions if any
|
||||
return False
|
||||
|
||||
|
||||
def update_token_balance(self, company_id: str, tokens_used: int) -> bool:
|
||||
"""Update company's token balance by subtracting tokens used"""
|
||||
with get_db() as db:
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if company and company.token_balance is not None:
|
||||
company.token_balance -= tokens_used
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_token_balance(self, company_id: str, tokens_to_add: int) -> bool:
|
||||
"""Add tokens to company's balance"""
|
||||
with get_db() as db:
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if company:
|
||||
if company.token_balance is None:
|
||||
company.token_balance = tokens_to_add
|
||||
else:
|
||||
company.token_balance += tokens_to_add
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_token_balance(self, company_id: str) -> Optional[int]:
|
||||
"""Get company's current token balance"""
|
||||
with get_db() as db:
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
return company.token_balance if company else None
|
||||
|
||||
def has_sufficient_tokens(self, company_id: str, required_tokens: int) -> bool:
|
||||
"""Check if company has sufficient tokens for an operation"""
|
||||
balance = self.get_token_balance(company_id)
|
||||
return balance is None or balance >= required_tokens # None means unlimited
|
||||
|
||||
Companies = CompanyTable()
|
Loading…
Reference in New Issue
Block a user