refactor: switch to meta and params, remove source

This commit is contained in:
Jun Siang Cheah 2024-05-21 22:05:16 +01:00
parent 7ccef3e77a
commit f21c8626d6
12 changed files with 70 additions and 107 deletions

View File

@ -78,9 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.MODEL_CONFIG = [
model.to_form() for model in Models.get_all_models_by_source("litellm")
]
app.state.MODEL_CONFIG = Models.get_all_models()
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config

View File

@ -66,9 +66,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [
model.to_form() for model in Models.get_all_models_by_source("ollama")
]
app.state.MODEL_CONFIG = Models.get_all_models()
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}

View File

@ -52,9 +52,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [
model.to_form() for model in Models.get_all_models_by_source("openai")
]
app.state.MODEL_CONFIG = Models.get_all_models()
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API

View File

@ -1,3 +1,5 @@
import json
from peewee import *
from peewee_migrate import Router
from playhouse.db_url import connect
@ -8,6 +10,16 @@ import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField):
    """Peewee field that persists arbitrary JSON-serializable data as text.

    Values are JSON-encoded on the way into the database and decoded on
    the way out; a SQL NULL round-trips as Python ``None``.
    """

    def db_value(self, value):
        # Serialize the Python value to its JSON text representation for storage.
        return json.dumps(value)

    def python_value(self, value):
        # NULL columns come back as None; leave them untouched.
        if value is None:
            return None
        return json.loads(value)
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file

View File

@ -1,4 +1,4 @@
"""Peewee migrations -- 008_add_models.py.
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
@ -39,20 +39,15 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Model(pw.Model):
id = pw.TextField()
source = pw.TextField()
base_model = pw.TextField(null=True)
id = pw.TextField(unique=True)
meta = pw.TextField()
base_model_id = pw.TextField(null=True)
name = pw.TextField()
params = pw.TextField()
class Meta:
table_name = "model"
indexes = (
# Create a unique index on the id, source columns
(("id", "source"), True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""

View File

@ -6,7 +6,7 @@ import peewee as pw
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel
from apps.web.internal.db import DB
from apps.web.internal.db import DB, JSONField
from config import SRC_LOG_LEVELS
@ -22,6 +22,12 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel):
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel):
description: str
"""
User-facing description of the model.
@ -34,50 +40,42 @@ class ModelParams(BaseModel):
class Model(pw.Model):
id = pw.TextField()
id = pw.TextField(unique=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
source = pw.TextField()
meta = JSONField()
"""
The source of the model, e.g., ollama, openai, or litellm.
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
base_model = pw.TextField(null=True)
base_model_id = pw.TextField(null=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
Currently unused - but will be used to support Modelfile like behaviour in the future
An optional pointer to the actual model that should be used when proxying requests.
Currently unused - but will be used to support Modelfile like behaviour in the future
"""
name = pw.TextField()
"""
The human-readable display name of the model.
The human-readable display name of the model.
"""
params = pw.TextField()
params = JSONField()
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
class Meta:
database = DB
indexes = (
# Create a unique index on the id, source columns
(("id", "source"), True),
)
class ModelModel(BaseModel):
id: str
source: str
base_model: Optional[str] = None
meta: ModelMeta
base_model_id: Optional[str] = None
name: str
params: str
def to_form(self) -> "ModelForm":
return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)})
params: ModelParams
####################
@ -85,17 +83,6 @@ class ModelModel(BaseModel):
####################
class ModelForm(BaseModel):
id: str
source: str
base_model: Optional[str] = None
name: str
params: dict
def to_db_model(self) -> ModelModel:
return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)})
class ModelsTable:
def __init__(
@ -108,51 +95,37 @@ class ModelsTable:
def get_all_models(self) -> list[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def get_all_models_by_source(self, source: str) -> list[ModelModel]:
return [
ModelModel(**model_to_dict(model))
for model in Model.select().where(Model.source == source)
]
def update_all_models(self, models: list[ModelForm]) -> bool:
def update_all_models(self, models: list[ModelModel]) -> bool:
try:
with self.db.atomic():
# Fetch current models from the database
current_models = self.get_all_models()
current_model_dict = {
(model.id, model.source): model for model in current_models
}
current_model_dict = {model.id: model for model in current_models}
# Create a set of model IDs and sources from the current models and the new models
# Create a set of model IDs from the current models and the new models
current_model_keys = set(current_model_dict.keys())
new_model_keys = set((model.id, model.source) for model in models)
new_model_keys = set(model.id for model in models)
# Determine which models need to be created, updated, or deleted
models_to_create = [
model
for model in models
if (model.id, model.source) not in current_model_keys
model for model in models if model.id not in current_model_keys
]
models_to_update = [
model
for model in models
if (model.id, model.source) in current_model_keys
model for model in models if model.id in current_model_keys
]
models_to_delete = current_model_keys - new_model_keys
# Perform the necessary database operations
for model in models_to_create:
Model.create(**model.to_db_model().model_dump())
Model.create(**model.model_dump())
for model in models_to_update:
Model.update(**model.to_db_model().model_dump()).where(
(Model.id == model.id) & (Model.source == model.source)
Model.update(**model.model_dump()).where(
Model.id == model.id
).execute()
for model_id, model_source in models_to_delete:
Model.delete().where(
(Model.id == model_id) & (Model.source == model_source)
).execute()
Model.delete().where(Model.id == model_id).execute()
return True
except Exception as e:

View File

@ -37,7 +37,7 @@ import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.web.models.models import Models, ModelModel, ModelForm
from apps.web.models.models import Models, ModelModel
from utils.utils import get_admin_user
from apps.rag.utils import rag_messages
@ -112,7 +112,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [model.to_form() for model in Models.get_all_models()]
app.state.MODEL_CONFIG = Models.get_all_models()
app.state.config.WEBHOOK_URL = WEBHOOK_URL
@ -320,7 +320,7 @@ async def update_model_filter_config(
class SetModelConfigForm(BaseModel):
models: List[ModelForm]
models: List[ModelModel]
@app.post("/api/config/models")
@ -333,19 +333,10 @@ async def update_model_config(
detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
)
ollama_app.state.MODEL_CONFIG = [
model for model in form_data.models if model.source == "ollama"
]
openai_app.state.MODEL_CONFIG = [
model for model in form_data.models if model.source == "openai"
]
litellm_app.state.MODEL_CONFIG = [
model for model in form_data.models if model.source == "litellm"
]
app.state.MODEL_CONFIG = [model for model in form_data.models]
ollama_app.state.MODEL_CONFIG = form_data.models
openai_app.state.MODEL_CONFIG = form_data.models
litellm_app.state.MODEL_CONFIG = form_data.models
app.state.MODEL_CONFIG = form_data.models
return {"models": app.state.MODEL_CONFIG}

View File

@ -227,16 +227,18 @@ export const getModelConfig = async (token: string): Promise<GlobalModelConfig>
export interface ModelConfig {
id: string;
name: string;
source: string;
base_model?: string;
meta: ModelMeta;
base_model_id?: string;
params: ModelParams;
}
export interface ModelParams {
export interface ModelMeta {
description?: string;
vision_capable?: boolean;
}
export interface ModelParams {}
export type GlobalModelConfig = ModelConfig[];
export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {

View File

@ -343,7 +343,7 @@
const hasImages = messages.some((message) =>
message.files?.some((file) => file.type === 'image')
);
if (hasImages && !(model.custom_info?.params.vision_capable ?? true)) {
if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) {
toast.error(
$i18n.t('Model {{modelName}} is not vision capable', {
modelName: model.custom_info?.name ?? model.name ?? model.id

View File

@ -359,7 +359,7 @@
if (!model) {
continue;
}
if (model.custom_info?.params.vision_capable ?? true) {
if (model.custom_info?.meta.vision_capable ?? true) {
visionCapableCount++;
}
}

View File

@ -307,10 +307,10 @@
</div>
</Tooltip>
{/if}
{#if item.info?.custom_info?.params.description}
{#if item.info?.custom_info?.meta.description}
<Tooltip
content={`${sanitizeResponseContent(
item.info.custom_info?.params.description
item.info.custom_info?.meta.description
).replaceAll('\n', '<br>')}`}
>
<div class="">

View File

@ -80,8 +80,8 @@
const model = $models.find((m) => m.id === selectedModelId);
if (model) {
modelName = model.custom_info?.name ?? model.name;
modelDescription = model.custom_info?.params.description ?? '';
modelIsVisionCapable = model.custom_info?.params.vision_capable ?? false;
modelDescription = model.custom_info?.meta.description ?? '';
modelIsVisionCapable = model.custom_info?.meta.vision_capable ?? false;
}
};
@ -518,18 +518,16 @@
if (!model) {
return;
}
const modelSource =
'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
// Remove any existing config
modelConfig = modelConfig.filter(
(m) => !(m.id === selectedModelId && m.source === modelSource)
(m) => !(m.id === selectedModelId)
);
// Add new config
modelConfig.push({
id: selectedModelId,
name: modelName,
source: modelSource,
params: {
params: {},
meta: {
description: modelDescription,
vision_capable: modelIsVisionCapable
}
@ -549,10 +547,8 @@
if (!model) {
return;
}
const modelSource =
'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
modelConfig = modelConfig.filter(
(m) => !(m.id === selectedModelId && m.source === modelSource)
(m) => !(m.id === selectedModelId)
);
await updateModelConfig(localStorage.token, modelConfig);
toast.success(