refac: switch to meta and params, remove source

This commit is contained in:
Jun Siang Cheah 2024-05-21 22:05:16 +01:00
parent 7ccef3e77a
commit f21c8626d6
12 changed files with 70 additions and 107 deletions

View File

@ -78,9 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.MODEL_CONFIG = [ app.state.MODEL_CONFIG = Models.get_all_models()
model.to_form() for model in Models.get_all_models_by_source("litellm")
]
app.state.ENABLE = ENABLE_LITELLM app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config app.state.CONFIG = litellm_config

View File

@ -66,9 +66,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [ app.state.MODEL_CONFIG = Models.get_all_models()
model.to_form() for model in Models.get_all_models_by_source("ollama")
]
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}

View File

@ -52,9 +52,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [ app.state.MODEL_CONFIG = Models.get_all_models()
model.to_form() for model in Models.get_all_models_by_source("openai")
]
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API

View File

@ -1,3 +1,5 @@
import json
from peewee import * from peewee import *
from peewee_migrate import Router from peewee_migrate import Router
from playhouse.db_url import connect from playhouse.db_url import connect
@ -8,6 +10,16 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField):
def db_value(self, value):
return json.dumps(value)
def python_value(self, value):
if value is not None:
return json.loads(value)
# Check if the file exists # Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"): if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file # Rename the file

View File

@ -1,4 +1,4 @@
"""Peewee migrations -- 008_add_models.py. """Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name):: Some examples (model - class or model name)::
@ -39,20 +39,15 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model @migrator.create_model
class Model(pw.Model): class Model(pw.Model):
id = pw.TextField() id = pw.TextField(unique=True)
source = pw.TextField() meta = pw.TextField()
base_model = pw.TextField(null=True) base_model_id = pw.TextField(null=True)
name = pw.TextField() name = pw.TextField()
params = pw.TextField() params = pw.TextField()
class Meta: class Meta:
table_name = "model" table_name = "model"
indexes = (
# Create a unique index on the id, source columns
(("id", "source"), True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""

View File

@ -6,7 +6,7 @@ import peewee as pw
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel from pydantic import BaseModel
from apps.web.internal.db import DB from apps.web.internal.db import DB, JSONField
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -22,6 +22,12 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table # ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel): class ModelParams(BaseModel):
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel):
description: str description: str
""" """
User-facing description of the model. User-facing description of the model.
@ -34,50 +40,42 @@ class ModelParams(BaseModel):
class Model(pw.Model): class Model(pw.Model):
id = pw.TextField() id = pw.TextField(unique=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
source = pw.TextField() meta = JSONField()
""" """
The source of the model, e.g., ollama, openai, or litellm. Holds a JSON encoded blob of metadata, see `ModelMeta`.
""" """
base_model = pw.TextField(null=True) base_model_id = pw.TextField(null=True)
""" """
An optional pointer to the actual model that should be used when proxying requests. An optional pointer to the actual model that should be used when proxying requests.
Currently unused - but will be used to support Modelfile like behaviour in the future Currently unused - but will be used to support Modelfile like behaviour in the future
""" """
name = pw.TextField() name = pw.TextField()
""" """
The human-readable display name of the model. The human-readable display name of the model.
""" """
params = pw.TextField() params = JSONField()
""" """
Holds a JSON encoded blob of parameters, see `ModelParams`. Holds a JSON encoded blob of parameters, see `ModelParams`.
""" """
class Meta: class Meta:
database = DB database = DB
indexes = (
# Create a unique index on the id, source columns
(("id", "source"), True),
)
class ModelModel(BaseModel): class ModelModel(BaseModel):
id: str id: str
source: str meta: ModelMeta
base_model: Optional[str] = None base_model_id: Optional[str] = None
name: str name: str
params: str params: ModelParams
def to_form(self) -> "ModelForm":
return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)})
#################### ####################
@ -85,17 +83,6 @@ class ModelModel(BaseModel):
#################### ####################
class ModelForm(BaseModel):
id: str
source: str
base_model: Optional[str] = None
name: str
params: dict
def to_db_model(self) -> ModelModel:
return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)})
class ModelsTable: class ModelsTable:
def __init__( def __init__(
@ -108,51 +95,37 @@ class ModelsTable:
def get_all_models(self) -> list[ModelModel]: def get_all_models(self) -> list[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()] return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def get_all_models_by_source(self, source: str) -> list[ModelModel]: def update_all_models(self, models: list[ModelModel]) -> bool:
return [
ModelModel(**model_to_dict(model))
for model in Model.select().where(Model.source == source)
]
def update_all_models(self, models: list[ModelForm]) -> bool:
try: try:
with self.db.atomic(): with self.db.atomic():
# Fetch current models from the database # Fetch current models from the database
current_models = self.get_all_models() current_models = self.get_all_models()
current_model_dict = { current_model_dict = {model.id: model for model in current_models}
(model.id, model.source): model for model in current_models
}
# Create a set of model IDs and sources from the current models and the new models # Create a set of model IDs from the current models and the new models
current_model_keys = set(current_model_dict.keys()) current_model_keys = set(current_model_dict.keys())
new_model_keys = set((model.id, model.source) for model in models) new_model_keys = set(model.id for model in models)
# Determine which models need to be created, updated, or deleted # Determine which models need to be created, updated, or deleted
models_to_create = [ models_to_create = [
model model for model in models if model.id not in current_model_keys
for model in models
if (model.id, model.source) not in current_model_keys
] ]
models_to_update = [ models_to_update = [
model model for model in models if model.id in current_model_keys
for model in models
if (model.id, model.source) in current_model_keys
] ]
models_to_delete = current_model_keys - new_model_keys models_to_delete = current_model_keys - new_model_keys
# Perform the necessary database operations # Perform the necessary database operations
for model in models_to_create: for model in models_to_create:
Model.create(**model.to_db_model().model_dump()) Model.create(**model.model_dump())
for model in models_to_update: for model in models_to_update:
Model.update(**model.to_db_model().model_dump()).where( Model.update(**model.model_dump()).where(
(Model.id == model.id) & (Model.source == model.source) Model.id == model.id
).execute() ).execute()
for model_id, model_source in models_to_delete: for model_id, model_source in models_to_delete:
Model.delete().where( Model.delete().where(Model.id == model_id).execute()
(Model.id == model_id) & (Model.source == model_source)
).execute()
return True return True
except Exception as e: except Exception as e:

View File

@ -37,7 +37,7 @@ import asyncio
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from apps.web.models.models import Models, ModelModel, ModelForm from apps.web.models.models import Models, ModelModel
from utils.utils import get_admin_user from utils.utils import get_admin_user
from apps.rag.utils import rag_messages from apps.rag.utils import rag_messages
@ -112,7 +112,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.MODEL_CONFIG = [model.to_form() for model in Models.get_all_models()] app.state.MODEL_CONFIG = Models.get_all_models()
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
@ -320,7 +320,7 @@ async def update_model_filter_config(
class SetModelConfigForm(BaseModel): class SetModelConfigForm(BaseModel):
models: List[ModelForm] models: List[ModelModel]
@app.post("/api/config/models") @app.post("/api/config/models")
@ -333,19 +333,10 @@ async def update_model_config(
detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"), detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
) )
ollama_app.state.MODEL_CONFIG = [ ollama_app.state.MODEL_CONFIG = form_data.models
model for model in form_data.models if model.source == "ollama" openai_app.state.MODEL_CONFIG = form_data.models
] litellm_app.state.MODEL_CONFIG = form_data.models
app.state.MODEL_CONFIG = form_data.models
openai_app.state.MODEL_CONFIG = [
model for model in form_data.models if model.source == "openai"
]
litellm_app.state.MODEL_CONFIG = [
model for model in form_data.models if model.source == "litellm"
]
app.state.MODEL_CONFIG = [model for model in form_data.models]
return {"models": app.state.MODEL_CONFIG} return {"models": app.state.MODEL_CONFIG}

View File

@ -227,16 +227,18 @@ export const getModelConfig = async (token: string): Promise<GlobalModelConfig>
export interface ModelConfig { export interface ModelConfig {
id: string; id: string;
name: string; name: string;
source: string; meta: ModelMeta;
base_model?: string; base_model_id?: string;
params: ModelParams; params: ModelParams;
} }
export interface ModelParams { export interface ModelMeta {
description?: string; description?: string;
vision_capable?: boolean; vision_capable?: boolean;
} }
export interface ModelParams {}
export type GlobalModelConfig = ModelConfig[]; export type GlobalModelConfig = ModelConfig[];
export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {

View File

@ -343,7 +343,7 @@
const hasImages = messages.some((message) => const hasImages = messages.some((message) =>
message.files?.some((file) => file.type === 'image') message.files?.some((file) => file.type === 'image')
); );
if (hasImages && !(model.custom_info?.params.vision_capable ?? true)) { if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) {
toast.error( toast.error(
$i18n.t('Model {{modelName}} is not vision capable', { $i18n.t('Model {{modelName}} is not vision capable', {
modelName: model.custom_info?.name ?? model.name ?? model.id modelName: model.custom_info?.name ?? model.name ?? model.id

View File

@ -359,7 +359,7 @@
if (!model) { if (!model) {
continue; continue;
} }
if (model.custom_info?.params.vision_capable ?? true) { if (model.custom_info?.meta.vision_capable ?? true) {
visionCapableCount++; visionCapableCount++;
} }
} }

View File

@ -307,10 +307,10 @@
</div> </div>
</Tooltip> </Tooltip>
{/if} {/if}
{#if item.info?.custom_info?.params.description} {#if item.info?.custom_info?.meta.description}
<Tooltip <Tooltip
content={`${sanitizeResponseContent( content={`${sanitizeResponseContent(
item.info.custom_info?.params.description item.info.custom_info?.meta.description
).replaceAll('\n', '<br>')}`} ).replaceAll('\n', '<br>')}`}
> >
<div class=""> <div class="">

View File

@ -80,8 +80,8 @@
const model = $models.find((m) => m.id === selectedModelId); const model = $models.find((m) => m.id === selectedModelId);
if (model) { if (model) {
modelName = model.custom_info?.name ?? model.name; modelName = model.custom_info?.name ?? model.name;
modelDescription = model.custom_info?.params.description ?? ''; modelDescription = model.custom_info?.meta.description ?? '';
modelIsVisionCapable = model.custom_info?.params.vision_capable ?? false; modelIsVisionCapable = model.custom_info?.meta.vision_capable ?? false;
} }
}; };
@ -518,18 +518,16 @@
if (!model) { if (!model) {
return; return;
} }
const modelSource =
'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
// Remove any existing config // Remove any existing config
modelConfig = modelConfig.filter( modelConfig = modelConfig.filter(
(m) => !(m.id === selectedModelId && m.source === modelSource) (m) => !(m.id === selectedModelId)
); );
// Add new config // Add new config
modelConfig.push({ modelConfig.push({
id: selectedModelId, id: selectedModelId,
name: modelName, name: modelName,
source: modelSource, params: {},
params: { meta: {
description: modelDescription, description: modelDescription,
vision_capable: modelIsVisionCapable vision_capable: modelIsVisionCapable
} }
@ -549,10 +547,8 @@
if (!model) { if (!model) {
return; return;
} }
const modelSource =
'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
modelConfig = modelConfig.filter( modelConfig = modelConfig.filter(
(m) => !(m.id === selectedModelId && m.source === modelSource) (m) => !(m.id === selectedModelId)
); );
await updateModelConfig(localStorage.token, modelConfig); await updateModelConfig(localStorage.token, modelConfig);
toast.success( toast.success(