diff --git a/backend/beyond_the_loop/models/companies.py b/backend/beyond_the_loop/models/companies.py index 269b785ee..6c1e5a630 100644 --- a/backend/beyond_the_loop/models/companies.py +++ b/backend/beyond_the_loop/models/companies.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict from typing import Optional from sqlalchemy.orm import relationship -from sqlalchemy import String, Column, Text +from sqlalchemy import String, Column, Text, Integer from open_webui.internal.db import get_db, Base @@ -22,6 +22,7 @@ class Company(Base): profile_image_url = Column(Text, nullable=True) default_model = Column(String, default="GPT 4o") allowed_models = Column(Text, nullable=True) + users = relationship("User", back_populates="company", cascade="all, delete-orphan") class CompanyModel(BaseModel): @@ -30,6 +31,7 @@ class CompanyModel(BaseModel): profile_image_url: Optional[str] = None default_model: Optional[str] = "GPT 4o" allowed_models: Optional[str] = None + token_balance: Optional[int] = 0 model_config = ConfigDict(from_attributes=True) @@ -136,6 +138,38 @@ class CompanyTable: # Handle exceptions if any return False - + def update_token_balance(self, company_id: str, tokens_used: int) -> bool: + """Update company's token balance by subtracting tokens used""" + with get_db() as db: + company = db.query(Company).filter(Company.id == company_id).first() + if company and company.token_balance is not None: + company.token_balance -= tokens_used + db.commit() + return True + return False + + def add_token_balance(self, company_id: str, tokens_to_add: int) -> bool: + """Add tokens to company's balance""" + with get_db() as db: + company = db.query(Company).filter(Company.id == company_id).first() + if company: + if company.token_balance is None: + company.token_balance = tokens_to_add + else: + company.token_balance += tokens_to_add + db.commit() + return True + return False + + def get_token_balance(self, company_id: str) -> Optional[int]: + """Get company's current token balance""" + with get_db() as db: + company = db.query(Company).filter(Company.id == company_id).first() + return company.token_balance if company else None + + def has_sufficient_tokens(self, company_id: str, required_tokens: int) -> bool: + """Check if company has sufficient tokens for an operation""" + balance = self.get_token_balance(company_id) + return balance is None or balance >= required_tokens # None means unlimited Companies = CompanyTable() \ No newline at end of file