feat: modelfiles to models

This commit is contained in:
Timothy J. Baek
2024-05-24 00:26:00 -07:00
parent 17e4be49c0
commit 4d57e08b38
18 changed files with 257 additions and 310 deletions

View File

@@ -6,7 +6,7 @@ from apps.web.routers import (
users,
chats,
documents,
modelfiles,
models,
prompts,
configs,
memories,
@@ -56,11 +56,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View File

@@ -1,3 +1,11 @@
################################################################################
# DEPRECATION NOTICE #
# #
# This file has been deprecated since version 0.2.0. #
# #
################################################################################
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict

View File

@@ -3,13 +3,18 @@ import logging
from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
import time
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -20,10 +25,8 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
pass
@@ -55,7 +58,6 @@ class Model(pw.Model):
base_model_id = pw.TextField(null=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
Currently unused - but will be used to support Modelfile like behaviour in the future
"""
name = pw.TextField()
@@ -73,8 +75,8 @@ class Model(pw.Model):
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
@@ -83,16 +85,36 @@ class Model(pw.Model):
class ModelModel(BaseModel):
id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ModelResponse(BaseModel):
id: str
name: str
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ModelForm(BaseModel):
id: str
base_model_id: Optional[str] = None
name: str
meta: ModelMeta
params: ModelParams
class ModelsTable:
def __init__(
self,
@@ -101,44 +123,47 @@ class ModelsTable:
self.db = db
self.db.create_tables([Model])
def get_all_models(self) -> list[ModelModel]:
def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]:
try:
model = Model.create(
**{
**model.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
return ModelModel(**model_to_dict(model))
except:
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def update_all_models(self, models: list[ModelModel]) -> bool:
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with self.db.atomic():
# Fetch current models from the database
current_models = self.get_all_models()
current_model_dict = {model.id: model for model in current_models}
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
except:
return None
# Create a set of model IDs from the current models and the new models
current_model_keys = set(current_model_dict.keys())
new_model_keys = set(model.id for model in models)
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
# Determine which models need to be created, updated, or deleted
models_to_create = [
model for model in models if model.id not in current_model_keys
]
models_to_update = [
model for model in models if model.id in current_model_keys
]
models_to_delete = current_model_keys - new_model_keys
# Perform the necessary database operations
for model in models_to_create:
Model.create(**model.model_dump())
for model in models_to_update:
Model.update(**model.model_dump()).where(
Model.id == model.id
).execute()
for model_id, model_source in models_to_delete:
Model.delete().where(Model.id == model_id).execute()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
except:
return None
def delete_model_by_id(self, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
return True
except Exception as e:
log.exception(e)
except:
return False

View File

@@ -1,124 +0,0 @@
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.modelfiles import (
Modelfiles,
ModelfileForm,
ModelfileTagNameForm,
ModelfileUpdateForm,
ModelfileResponse,
)
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetModelfiles
############################
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
return Modelfiles.get_modelfiles(skip, limit)
############################
# CreateNewModelfile
############################
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# GetModelfileByTagName
############################
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelfileByTagName
############################
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeleteModelfileByTagName
############################
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View File

@@ -0,0 +1,89 @@
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
###########################
# getAllModels
###########################
@router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models()
############################
# AddNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(form_data: ModelForm, user=Depends(get_admin_user)):
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# GetModelById
############################
@router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelById
############################
@router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeleteModelById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result

View File

@@ -320,33 +320,6 @@ async def update_model_filter_config(
}
class SetModelConfigForm(BaseModel):
models: List[ModelModel]
@app.post("/api/config/models")
async def update_model_config(
form_data: SetModelConfigForm, user=Depends(get_admin_user)
):
if not Models.update_all_models(form_data.models):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
)
ollama_app.state.MODEL_CONFIG = form_data.models
openai_app.state.MODEL_CONFIG = form_data.models
litellm_app.state.MODEL_CONFIG = form_data.models
app.state.MODEL_CONFIG = form_data.models
return {"models": app.state.MODEL_CONFIG}
@app.get("/api/config/models")
async def get_model_config(user=Depends(get_admin_user)):
return {"models": app.state.MODEL_CONFIG}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {