From f21c8626d6598ea2e67fd74b042b55a4c81d3ca5 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Tue, 21 May 2024 22:05:16 +0100 Subject: [PATCH] refac: switch to meta and params, remove source --- backend/apps/litellm/main.py | 4 +- backend/apps/ollama/main.py | 4 +- backend/apps/openai/main.py | 4 +- backend/apps/web/internal/db.py | 12 +++ .../{008_add_models.py => 009_add_models.py} | 13 +-- backend/apps/web/models/models.py | 85 +++++++------------ backend/main.py | 23 ++--- src/lib/apis/index.ts | 8 +- src/lib/components/chat/Chat.svelte | 2 +- src/lib/components/chat/MessageInput.svelte | 2 +- .../chat/ModelSelector/Selector.svelte | 4 +- .../components/chat/Settings/Models.svelte | 16 ++-- 12 files changed, 70 insertions(+), 107 deletions(-) rename backend/apps/web/internal/migrations/{008_add_models.py => 009_add_models.py} (86%) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 9286c36d5..bef91443a 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -78,9 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file: app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value -app.state.MODEL_CONFIG = [ - model.to_form() for model in Models.get_all_models_by_source("litellm") -] +app.state.MODEL_CONFIG = Models.get_all_models() app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index a310a25e0..6c85bf819 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -66,9 +66,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = [ - model.to_form() for model in Models.get_all_models_by_source("ollama") -] +app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index c5f35d315..678e64746 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -52,9 +52,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = [ - model.to_form() for model in Models.get_all_models_by_source("openai") -] +app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index 136e3fafc..398f5a231 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,3 +1,5 @@ +import json + from peewee import * from peewee_migrate import Router from playhouse.db_url import connect @@ -8,6 +10,16 @@ import logging log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) + +class JSONField(TextField): + def db_value(self, value): + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file diff --git a/backend/apps/web/internal/migrations/008_add_models.py b/backend/apps/web/internal/migrations/009_add_models.py similarity index 86% rename from backend/apps/web/internal/migrations/008_add_models.py rename to backend/apps/web/internal/migrations/009_add_models.py index 982927113..276769441 100644 --- a/backend/apps/web/internal/migrations/008_add_models.py +++ b/backend/apps/web/internal/migrations/009_add_models.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 008_add_models.py. +"""Peewee migrations -- 009_add_models.py. Some examples (model - class or model name):: @@ -39,20 +39,15 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): @migrator.create_model class Model(pw.Model): - id = pw.TextField() - source = pw.TextField() - base_model = pw.TextField(null=True) + id = pw.TextField(unique=True) + meta = pw.TextField() + base_model_id = pw.TextField(null=True) name = pw.TextField() params = pw.TextField() class Meta: table_name = "model" - indexes = ( - # Create a unique index on the id, source columns - (("id", "source"), True), - ) - def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" diff --git a/backend/apps/web/models/models.py b/backend/apps/web/models/models.py index d1cb082e8..cd734e67b 100644 --- a/backend/apps/web/models/models.py +++ b/backend/apps/web/models/models.py @@ -6,7 +6,7 @@ import peewee as pw from playhouse.shortcuts import model_to_dict from pydantic import BaseModel -from apps.web.internal.db import DB +from apps.web.internal.db import DB, JSONField from config import SRC_LOG_LEVELS @@ -22,6 +22,12 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) # ModelParams is a model for the data stored in the params field of the Model table # It isn't currently used in the backend, but it's here as a reference class ModelParams(BaseModel): + pass + + +# ModelMeta is a model for the data stored in the meta field of the Model table +# It isn't currently used in the backend, but it's here as a reference +class ModelMeta(BaseModel): description: str """ User-facing description of the model. @@ -34,50 +40,42 @@ class ModelParams(BaseModel): class Model(pw.Model): - id = pw.TextField() + id = pw.TextField(unique=True) """ The model's id as used in the API. If set to an existing model, it will override the model. """ - source = pw.TextField() + meta = JSONField() """ - The source of the model, e.g., ollama, openai, or litellm. + Holds a JSON encoded blob of metadata, see `ModelMeta`. """ - base_model = pw.TextField(null=True) + base_model_id = pw.TextField(null=True) """ - An optional pointer to the actual model that should be used when proxying requests. - Currently unused - but will be used to support Modelfile like behaviour in the future + An optional pointer to the actual model that should be used when proxying requests. + Currently unused - but will be used to support Modelfile like behaviour in the future """ name = pw.TextField() """ - The human-readable display name of the model. + The human-readable display name of the model. """ - params = pw.TextField() + params = JSONField() """ - Holds a JSON encoded blob of parameters, see `ModelParams`. + Holds a JSON encoded blob of parameters, see `ModelParams`. """ class Meta: database = DB - indexes = ( - # Create a unique index on the id, source columns - (("id", "source"), True), - ) - class ModelModel(BaseModel): id: str - source: str - base_model: Optional[str] = None + meta: ModelMeta + base_model_id: Optional[str] = None name: str - params: str - - def to_form(self) -> "ModelForm": - return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)}) + params: ModelParams #################### @@ -85,17 +83,6 @@ class ModelModel(BaseModel): #################### -class ModelForm(BaseModel): - id: str - source: str - base_model: Optional[str] = None - name: str - params: dict - - def to_db_model(self) -> ModelModel: - return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)}) - - class ModelsTable: def __init__( @@ -108,51 +95,37 @@ class ModelsTable: def get_all_models(self) -> list[ModelModel]: return [ModelModel(**model_to_dict(model)) for model in Model.select()] - def get_all_models_by_source(self, source: str) -> list[ModelModel]: - return [ - ModelModel(**model_to_dict(model)) - for model in Model.select().where(Model.source == source) - ] - - def update_all_models(self, models: list[ModelForm]) -> bool: + def update_all_models(self, models: list[ModelModel]) -> bool: try: with self.db.atomic(): # Fetch current models from the database current_models = self.get_all_models() - current_model_dict = { - (model.id, model.source): model for model in current_models - } + current_model_dict = {model.id: model for model in current_models} - # Create a set of model IDs and sources from the current models and the new models + # Create a set of model IDs from the current models and the new models current_model_keys = set(current_model_dict.keys()) - new_model_keys = set((model.id, model.source) for model in models) + new_model_keys = set(model.id for model in models) # Determine which models need to be created, updated, or deleted models_to_create = [ - model - for model in models - if (model.id, model.source) not in current_model_keys + model for model in models if model.id not in current_model_keys ] models_to_update = [ - model - for model in models - if (model.id, model.source) in current_model_keys + model for model in models if model.id in current_model_keys ] models_to_delete = current_model_keys - new_model_keys # Perform the necessary database operations for model in models_to_create: - Model.create(**model.to_db_model().model_dump()) + Model.create(**model.model_dump()) for model in models_to_update: - Model.update(**model.to_db_model().model_dump()).where( - (Model.id == model.id) & (Model.source == model.source) + Model.update(**model.model_dump()).where( + Model.id == model.id ).execute() for model_id, model_source in models_to_delete: - Model.delete().where( - (Model.id == model_id) & (Model.source == model_source) - ).execute() + Model.delete().where(Model.id == model_id).execute() return True except Exception as e: diff --git a/backend/main.py b/backend/main.py index 4ff77a7b6..e86c1dbb5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -37,7 +37,7 @@ import asyncio from pydantic import BaseModel from typing import List, Optional -from apps.web.models.models import Models, ModelModel, ModelForm +from apps.web.models.models import Models, ModelModel from utils.utils import get_admin_user from apps.rag.utils import rag_messages @@ -112,7 +112,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = [model.to_form() for model in Models.get_all_models()] +app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.WEBHOOK_URL = WEBHOOK_URL @@ -320,7 +320,7 @@ async def update_model_filter_config( class SetModelConfigForm(BaseModel): - models: List[ModelForm] + models: List[ModelModel] @app.post("/api/config/models") @@ -333,19 +333,10 @@ async def update_model_config( detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"), ) - ollama_app.state.MODEL_CONFIG = [ - model for model in form_data.models if model.source == "ollama" - ] - - openai_app.state.MODEL_CONFIG = [ - model for model in form_data.models if model.source == "openai" - ] - - litellm_app.state.MODEL_CONFIG = [ - model for model in form_data.models if model.source == "litellm" - ] - - app.state.MODEL_CONFIG = [model for model in form_data.models] + ollama_app.state.MODEL_CONFIG = form_data.models + openai_app.state.MODEL_CONFIG = form_data.models + litellm_app.state.MODEL_CONFIG = form_data.models + app.state.MODEL_CONFIG = form_data.models return {"models": app.state.MODEL_CONFIG} diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index d64ffb658..a7b59a7ca 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -227,16 +227,18 @@ export const getModelConfig = async (token: string): Promise export interface ModelConfig { id: string; name: string; - source: string; - base_model?: string; + meta: ModelMeta; + base_model_id?: string; params: ModelParams; } -export interface ModelParams { +export interface ModelMeta { description?: string; vision_capable?: boolean; } +export interface ModelParams {} + export type GlobalModelConfig = ModelConfig[]; export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 2465b53cd..abd75ea39 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -343,7 +343,7 @@ const hasImages = messages.some((message) => message.files?.some((file) => file.type === 'image') ); - if (hasImages && !(model.custom_info?.params.vision_capable ?? true)) { + if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) { toast.error( $i18n.t('Model {{modelName}} is not vision capable', { modelName: model.custom_info?.name ?? model.name ?? model.id diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index a1ebe9352..3f7250c4a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -359,7 +359,7 @@ if (!model) { continue; } - if (model.custom_info?.params.vision_capable ?? true) { + if (model.custom_info?.meta.vision_capable ?? true) { visionCapableCount++; } } diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index bf5ef4c93..503f950e2 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -307,10 +307,10 @@ {/if} - {#if item.info?.custom_info?.params.description} + {#if item.info?.custom_info?.meta.description} ')}`} >
diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte index 43f74418f..58ebaaf55 100644 --- a/src/lib/components/chat/Settings/Models.svelte +++ b/src/lib/components/chat/Settings/Models.svelte @@ -80,8 +80,8 @@ const model = $models.find((m) => m.id === selectedModelId); if (model) { modelName = model.custom_info?.name ?? model.name; - modelDescription = model.custom_info?.params.description ?? ''; - modelIsVisionCapable = model.custom_info?.params.vision_capable ?? false; + modelDescription = model.custom_info?.meta.description ?? ''; + modelIsVisionCapable = model.custom_info?.meta.vision_capable ?? false; } }; @@ -518,18 +518,16 @@ if (!model) { return; } - const modelSource = - 'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai'; // Remove any existing config modelConfig = modelConfig.filter( - (m) => !(m.id === selectedModelId && m.source === modelSource) + (m) => !(m.id === selectedModelId) ); // Add new config modelConfig.push({ id: selectedModelId, name: modelName, - source: modelSource, - params: { + params: {}, + meta: { description: modelDescription, vision_capable: modelIsVisionCapable } @@ -549,10 +547,8 @@ if (!model) { return; } - const modelSource = - 'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai'; modelConfig = modelConfig.filter( - (m) => !(m.id === selectedModelId && m.source === modelSource) + (m) => !(m.id === selectedModelId) ); await updateModelConfig(localStorage.token, modelConfig); toast.success(