"""Company persistence layer: SQLAlchemy schema, Pydantic models, and data-access helpers."""

import json
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Integer, String, Column, Text, Boolean
from sqlalchemy.orm import relationship

from open_webui.internal.db import get_db, Base

# Constants
NO_COMPANY = "NO_COMPANY"
EIGHTY_PERCENT_CREDIT_LIMIT = 4000

####################
# Company DB Schema
####################


class Company(Base):
    __tablename__ = "company"

    id = Column(String, primary_key=True, unique=True)
    name = Column(String, nullable=False)
    profile_image_url = Column(Text, nullable=True)
    default_model = Column(String, nullable=True)
    # JSON-encoded list of model IDs; read/written with json.loads/json.dumps
    # by CompanyTable.add_model / remove_model.
    allowed_models = Column(Text, nullable=True)
    credit_balance = Column(Integer, default=0)
    auto_recharge = Column(Boolean, default=False)
    # NOTE(review): stored as a plain, unencrypted String column. If this ever
    # holds real card numbers that is a PCI-DSS concern — confirm its contents.
    credit_card_number = Column(String, nullable=True)

    users = relationship("User", back_populates="company", cascade="all, delete-orphan")


class CompanyModel(BaseModel):
    id: str
    name: str
    profile_image_url: Optional[str] = None
    default_model: Optional[str] = "GPT 4o"
    allowed_models: Optional[str] = None
    credit_balance: Optional[int] = 0
    auto_recharge: Optional[bool] = False
    credit_card_number: Optional[str] = None

    # Allows model_validate() to read attributes straight off a Company ORM row.
    model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################


class CompanyModelForm(BaseModel):
    id: str
    model_id: str


class CompanyForm(BaseModel):
    company: dict


class CompanyResponse(BaseModel):
    id: str
    name: str
    profile_image_url: Optional[str] = None
    default_model: Optional[str] = "GPT 4o"
    # NOTE(review): Optional without a default is still a *required* field in
    # Pydantic v2 (it may be None but must be supplied) — confirm intended.
    allowed_models: Optional[str]
    auto_recharge: bool


class CompanyTable:
    """Data-access helpers for the ``company`` table."""

    @staticmethod
    def _parse_allowed_models(raw: Optional[str]) -> list:
        """Decode the JSON ``allowed_models`` column, treating NULL/empty as ``[]``."""
        return json.loads(raw) if raw else []

    def get_company_by_id(self, company_id: str) -> Optional[CompanyModel]:
        """Return the company with ``company_id`` as a ``CompanyModel``.

        Returns None (after printing a message) when the company does not
        exist or any error occurs.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if company is None:
                    print(f"Company with ID {company_id} not found.")
                    return None
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error getting company: {e}")
            return None

    def update_company_by_id(self, id: str, updated: dict) -> Optional[CompanyModel]:
        """Apply the column values in ``updated`` to company ``id`` and return it.

        Returns the refreshed ``CompanyModel``, or None when the company does
        not exist or any error occurs (the error is printed).
        """
        try:
            with get_db() as db:
                db.query(Company).filter_by(id=id).update(updated)
                db.commit()
                company = db.query(Company).filter_by(id=id).first()
                if company is None:
                    print(f"Company with ID {id} not found.")
                    return None
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error updating company: {e}")
            return None

    def update_auto_recharge(self, company_id: str, auto_recharge: bool) -> Optional[CompanyModel]:
        """Set the ``auto_recharge`` flag for a company and return the updated row.

        Returns None when the company does not exist or any error occurs.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if not company:
                    print(f"Company with ID {company_id} not found.")
                    return None

                db.query(Company).filter_by(id=company_id).update({"auto_recharge": auto_recharge})
                db.commit()

                updated_company = db.query(Company).filter_by(id=company_id).first()
                return CompanyModel.model_validate(updated_company)
        except Exception as e:
            print(f"Error updating auto_recharge for company {company_id}: {e}")
            return None

    def get_auto_recharge(self, company_id: str) -> Optional[bool]:
        """Return the company's ``auto_recharge`` flag, or None if missing/on error."""
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if not company:
                    print(f"Company with ID {company_id} not found.")
                    return None
                return company.auto_recharge
        except Exception as e:
            print(f"Error retrieving auto_recharge for company {company_id}: {e}")
            return None

    def add_model(self, company_id: str, model_id: str) -> Optional[bool]:
        """Append ``model_id`` to the company's ``allowed_models`` list.

        Returns:
            True if the model was added; False if it was already present or an
            error occurred; None if the company does not exist.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                # Check existence before reading any attribute off ``company``.
                if not company:
                    return None

                current_models = self._parse_allowed_models(company.allowed_models)
                if model_id in current_models:
                    # Model already allowed for this company; nothing to write.
                    return False

                current_models.append(model_id)
                payload = {"allowed_models": json.dumps(current_models)}
                db.query(Company).filter_by(id=company_id).update(payload)
                db.commit()
                return True
        except Exception as e:
            print(f"Error adding model {model_id} to company {company_id}: {e}")
            return False

    def remove_model(self, company_id: str, model_id: str) -> Optional[bool]:
        """Remove ``model_id`` from the company's ``allowed_models`` list.

        Returns:
            True if the model was removed; False if it was not present or an
            error occurred; None if the company does not exist.
        """
        try:
            with get_db() as db:
                company = db.query(Company).filter_by(id=company_id).first()
                if not company:
                    return None

                # A NULL column means "no models yet", not an error.
                current_models = self._parse_allowed_models(company.allowed_models)
                if model_id not in current_models:
                    return False

                current_models.remove(model_id)
                payload = {"allowed_models": json.dumps(current_models)}
                db.query(Company).filter_by(id=company_id).update(payload)
                db.commit()
                return True
        except Exception as e:
            print(f"Error removing model {model_id} from company {company_id}: {e}")
            return False

    def update_credit_balance(self, company_id: str, credits_used: int) -> bool:
        """Update company's credit balance by subtracting credits used.

        Unlike ``subtract_credit_balance`` this does not check for sufficient
        balance, so the stored balance may become negative. Returns False when
        the company is missing or its balance is NULL.
        """
        with get_db() as db:
            company = db.query(Company).filter(Company.id == company_id).first()
            if company and company.credit_balance is not None:
                company.credit_balance -= credits_used
                db.commit()
                return True
            return False

    def add_credit_balance(self, company_id: str, credits_to_add: int) -> bool:
        """Add credits to company's balance (a NULL balance is treated as 0).

        Returns False when the company does not exist.
        """
        with get_db() as db:
            company = db.query(Company).filter(Company.id == company_id).first()
            if company:
                if company.credit_balance is None:
                    company.credit_balance = credits_to_add
                else:
                    company.credit_balance += credits_to_add
                db.commit()
                return True
            return False

    def subtract_credit_balance(self, company_id: str, credits_to_subtract: int) -> bool:
        """Subtract credits from company's balance only if enough are available.

        Returns False when the company is missing, its balance is NULL, or the
        balance is smaller than ``credits_to_subtract``.
        """
        with get_db() as db:
            company = db.query(Company).filter(Company.id == company_id).first()
            # NOTE(review): this is a read-check-write in Python; if called
            # concurrently, two callers could both pass the balance check.
            # An atomic UPDATE ... WHERE credit_balance >= :n may be safer.
            if company:
                if company.credit_balance is not None and company.credit_balance >= credits_to_subtract:
                    company.credit_balance -= credits_to_subtract
                    db.commit()
                    return True
            return False

    def get_credit_balance(self, company_id: str) -> Optional[int]:
        """Get company's current credit balance, or None if the company is missing."""
        with get_db() as db:
            company = db.query(Company).filter(Company.id == company_id).first()
            return company.credit_balance if company else None

    def create_company(self, company_data: dict) -> Optional[CompanyModel]:
        """Create a new company from ``company_data`` (Company column kwargs).

        Returns the persisted ``CompanyModel``, or None on any error (printed).
        """
        try:
            with get_db() as db:
                company = Company(**company_data)
                db.add(company)
                db.commit()
                db.refresh(company)
                return CompanyModel.model_validate(company)
        except Exception as e:
            print(f"Error creating company: {e}")
            return None


Companies = CompanyTable()