diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 2b6966381..2cda11eb5 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -6,7 +6,7 @@ from apps.web.routers import ( users, chats, documents, - modelfiles, + models, prompts, configs, memories, @@ -56,11 +56,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) - app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py index 1d60d7c55..fe278ed5f 100644 --- a/backend/apps/web/models/modelfiles.py +++ b/backend/apps/web/models/modelfiles.py @@ -1,3 +1,11 @@ +################################################################################ +# DEPRECATION NOTICE # +# # +# This file has been deprecated since version 0.2.0. # +# # +################################################################################ + + from pydantic import BaseModel from peewee import * from playhouse.shortcuts import model_to_dict diff --git a/backend/apps/web/models/models.py b/backend/apps/web/models/models.py index ae8d618fa..94e349cbf 100644 --- a/backend/apps/web/models/models.py +++ b/backend/apps/web/models/models.py @@ -3,13 +3,18 @@ import logging from typing import Optional import peewee as pw +from peewee import * + from playhouse.shortcuts import model_to_dict from pydantic import BaseModel, ConfigDict from apps.web.internal.db import DB, JSONField +from typing import List, Union, Optional from config import SRC_LOG_LEVELS +import time + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -20,10 +25,8 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) # ModelParams is a model for the data stored in the params field of the Model table -# It isn't currently used in the backend, but it's here as a reference class ModelParams(BaseModel): model_config = ConfigDict(extra="allow") - pass @@ -55,7 +58,6 @@ class Model(pw.Model): base_model_id = pw.TextField(null=True) """ An optional pointer to the actual model that should be used when proxying requests. - Currently unused - but will be used to support Modelfile like behaviour in the future """ name = pw.TextField() @@ -73,8 +75,8 @@ class Model(pw.Model): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch + updated_at = BigIntegerField() + created_at = BigIntegerField() class Meta: database = DB @@ -83,16 +85,36 @@ class Model(pw.Model): class ModelModel(BaseModel): id: str base_model_id: Optional[str] = None + name: str params: ModelParams meta: ModelMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + #################### # Forms #################### +class ModelResponse(BaseModel): + id: str + name: str + meta: ModelMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ModelForm(BaseModel): + id: str + base_model_id: Optional[str] = None + name: str + meta: ModelMeta + params: ModelParams + + class ModelsTable: def __init__( self, @@ -101,44 +123,47 @@ class ModelsTable: self.db = db self.db.create_tables([Model]) - def get_all_models(self) -> list[ModelModel]: + def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]: + try: + model = Model.create( + **{ + **model.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + return ModelModel(**model_to_dict(model)) + except: + return None + + def get_all_models(self) -> List[ModelModel]: return [ModelModel(**model_to_dict(model)) for model in Model.select()] - def update_all_models(self, models: list[ModelModel]) -> bool: + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - with self.db.atomic(): - # Fetch current models from the database - current_models = self.get_all_models() - current_model_dict = {model.id: model for model in current_models} + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except: + return None - # Create a set of model IDs from the current models and the new models - current_model_keys = set(current_model_dict.keys()) - new_model_keys = set(model.id for model in models) + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + try: + # update only the fields that are present in the model + query = Model.update(**model.model_dump()).where(Model.id == id) + query.execute() - # Determine which models need to be created, updated, or deleted - models_to_create = [ - model for model in models if model.id not in current_model_keys - ] - models_to_update = [ - model for model in models if model.id in current_model_keys - ] - models_to_delete = current_model_keys - new_model_keys - - # Perform the necessary database operations - for model in models_to_create: - Model.create(**model.model_dump()) - - for model in models_to_update: - Model.update(**model.model_dump()).where( - Model.id == model.id - ).execute() - - for model_id, model_source in models_to_delete: - Model.delete().where(Model.id == model_id).execute() + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except: + return None + def delete_model_by_id(self, id: str) -> bool: + try: + query = Model.delete().where(Model.id == id) + query.execute() return True - except Exception as e: - log.exception(e) + except: return False diff --git a/backend/apps/web/routers/modelfiles.py b/backend/apps/web/routers/modelfiles.py deleted file mode 100644 index 3cdbf8a74..000000000 --- a/backend/apps/web/routers/modelfiles.py +++ /dev/null @@ -1,124 +0,0 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import json -from apps.web.models.modelfiles import ( - Modelfiles, - ModelfileForm, - ModelfileTagNameForm, - ModelfileUpdateForm, - ModelfileResponse, -) - -from utils.utils import get_current_user, get_admin_user -from constants import ERROR_MESSAGES - -router = APIRouter() - -############################ -# GetModelfiles -############################ - - -@router.get("/", response_model=List[ModelfileResponse]) -async def get_modelfiles( - skip: int = 0, limit: int = 50, user=Depends(get_current_user) -): - return Modelfiles.get_modelfiles(skip, limit) - - -############################ -# CreateNewModelfile -############################ - - -@router.post("/create", response_model=Optional[ModelfileResponse]) -async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): - modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# GetModelfileByTagName -############################ - - -@router.post("/", response_model=Optional[ModelfileResponse]) -async def get_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_current_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdateModelfileByTagName -############################ - - -@router.post("/update", response_model=Optional[ModelfileResponse]) -async def update_modelfile_by_tag_name( - form_data: ModelfileUpdateForm, user=Depends(get_admin_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - if modelfile: - updated_modelfile = { - **json.loads(modelfile.modelfile), - **form_data.modelfile, - } - - modelfile = Modelfiles.update_modelfile_by_tag_name( - form_data.tag_name, updated_modelfile - ) - - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - - -############################ -# DeleteModelfileByTagName -############################ - - -@router.delete("/delete", response_model=bool) -async def delete_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_admin_user) -): - result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) - return result diff --git a/backend/apps/web/routers/models.py b/backend/apps/web/routers/models.py new file mode 100644 index 000000000..696e359d9 --- /dev/null +++ b/backend/apps/web/routers/models.py @@ -0,0 +1,89 @@ +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json +from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse + +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +########################### +# getAllModels +########################### + + +@router.get("/", response_model=List[ModelResponse]) +async def get_models(user=Depends(get_verified_user)): + return Models.get_all_models() + + +############################ +# AddNewModel +############################ + + +@router.post("/add", response_model=Optional[ModelModel]) +async def add_new_model(form_data: ModelForm, user=Depends(get_admin_user)): + model = Models.insert_new_model(form_data, user.id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# GetModelById +############################ + + +@router.get("/{id}", response_model=Optional[ModelModel]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + +@router.post("/{id}/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + id: str, form_data: ModelForm, user=Depends(get_admin_user) +): + model = Models.get_model_by_id(id) + if model: + model = Models.update_model_by_id(id, form_data) + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeleteModelById +############################ + + +@router.delete("/{id}/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_admin_user)): + result = Models.delete_model_by_id(id) + return result diff --git a/backend/main.py b/backend/main.py index 7e505c7cd..e19ab57fa 100644 --- a/backend/main.py +++ b/backend/main.py @@ -320,33 +320,6 @@ async def update_model_filter_config( } -class SetModelConfigForm(BaseModel): - models: List[ModelModel] - - -@app.post("/api/config/models") -async def update_model_config( - form_data: SetModelConfigForm, user=Depends(get_admin_user) -): - if not Models.update_all_models(form_data.models): - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"), - ) - - ollama_app.state.MODEL_CONFIG = form_data.models - openai_app.state.MODEL_CONFIG = form_data.models - litellm_app.state.MODEL_CONFIG = form_data.models - app.state.MODEL_CONFIG = form_data.models - - return {"models": app.state.MODEL_CONFIG} - - -@app.get("/api/config/models") -async def get_model_config(user=Depends(get_admin_user)): - return {"models": app.state.MODEL_CONFIG} - - @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { diff --git a/src/lib/apis/modelfiles/index.ts b/src/lib/apis/models/index.ts similarity index 66% rename from src/lib/apis/modelfiles/index.ts rename to src/lib/apis/models/index.ts index 91af5e381..56c299fb2 100644 --- a/src/lib/apis/modelfiles/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,18 +1,16 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewModelfile = async (token: string, modelfile: object) => { +export const addNewModel = async (token: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, - body: JSON.stringify({ - modelfile: modelfile - }) + body: JSON.stringify(model) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => { return res; }; -export const getModelfiles = async (token: string = '') => { +export const getModels = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, { method: 'GET', headers: { Accept: 'application/json', @@ -59,62 +57,19 @@ export const getModelfiles = async (token: string = '') => { throw error; } - return res.map((modelfile) => modelfile.modelfile); + return res; }; -export const getModelfileByTagName = async (token: string, tagName: string) => { +export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { - method: 'POST', + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, { + method: 'GET', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res.modelfile; -}; - -export const updateModelfileByTagName = async ( - token: string, - tagName: string, - modelfile: object -) => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName, - modelfile: modelfile - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -137,19 +92,49 @@ export const updateModelfileByTagName = async ( return res; }; -export const deleteModelfileByTagName = async (token: string, tagName: string) => { +export const updateModelById = async (token: string, id: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteModelById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/delete`, { method: 'DELETE', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 503f950e2..e72eb7e10 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -14,7 +14,7 @@ import { toast } from 'svelte-sonner'; import { capitalizeFirstLetter, - getModels, + getAllModels, sanitizeResponseContent, splitStream } from '$lib/utils'; @@ -159,7 +159,7 @@ }) ); - models.set(await getModels(localStorage.token)); + models.set(await getAllModels(localStorage.token)); } else { toast.error($i18n.t('Download canceled')); } diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/chat/Settings/Connections.svelte index b9978e129..f928e0f44 100644 --- a/src/lib/components/chat/Settings/Connections.svelte +++ b/src/lib/components/chat/Settings/Connections.svelte @@ -23,7 +23,7 @@ const i18n = getContext('i18n'); - export let getModels: Function; + export let getAllModels: Function; // External let OLLAMA_BASE_URLS = ['']; @@ -38,7 +38,7 @@ OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS); - await models.set(await getModels()); + await models.set(await getAllModels()); }; const updateOllamaUrlsHandler = async () => { @@ -51,7 +51,7 @@ if (ollamaVersion) { toast.success($i18n.t('Server connection verified')); - await models.set(await getModels()); + await models.set(await getAllModels()); } }; diff --git a/src/lib/components/chat/Settings/General.svelte b/src/lib/components/chat/Settings/General.svelte index a30cc1896..060eb8b20 100644 --- a/src/lib/components/chat/Settings/General.svelte +++ b/src/lib/components/chat/Settings/General.svelte @@ -11,7 +11,7 @@ import AdvancedParams from './Advanced/AdvancedParams.svelte'; export let saveSettings: Function; - export let getModels: Function; + export let getAllModels: Function; // General let themes = ['dark', 'light', 'rose-pine dark', 'rose-pine-dawn light', 'oled-dark']; diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index 19e050ca8..ff117f6ae 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -42,7 +42,7 @@ let imageSize = ''; let steps = 50; - const getModels = async () => { + const getAllModels = async () => { models = await getImageGenerationModels(localStorage.token).catch((error) => { toast.error(error); return null; @@ -66,7 +66,7 @@ if (res) { COMFYUI_BASE_URL = res.COMFYUI_BASE_URL; - await getModels(); + await getAllModels(); if (models) { toast.success($i18n.t('Server connection verified')); @@ -85,7 +85,7 @@ if (res) { AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL; - await getModels(); + await getAllModels(); if (models) { toast.success($i18n.t('Server connection verified')); @@ -112,7 +112,7 @@ if (enableImageGeneration) { config.set(await getBackendConfig(localStorage.token)); - getModels(); + getAllModels(); } }; @@ -141,7 +141,7 @@ steps = await getImageSteps(localStorage.token); if (enableImageGeneration) { - getModels(); + getAllModels(); } } }); diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte index 58ebaaf55..8f5cce76f 100644 --- a/src/lib/components/chat/Settings/Models.svelte +++ b/src/lib/components/chat/Settings/Models.svelte @@ -22,7 +22,7 @@ const i18n = getContext('i18n'); - export let getModels: Function; + export let getAllModels: Function; let showLiteLLM = false; let showLiteLLMParams = false; @@ -261,7 +261,7 @@ }) ); - models.set(await getModels(localStorage.token)); + models.set(await getAllModels(localStorage.token)); } else { toast.error($i18n.t('Download canceled')); } @@ -424,7 +424,7 @@ modelTransferring = false; uploadProgress = null; - models.set(await getModels()); + models.set(await getAllModels()); }; const deleteModelHandler = async () => { @@ -439,7 +439,7 @@ } deleteModelTag = ''; - models.set(await getModels()); + models.set(await getAllModels()); }; const cancelModelPullHandler = async (model: string) => { @@ -488,7 +488,7 @@ liteLLMMaxTokens = ''; liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); - models.set(await getModels()); + models.set(await getAllModels()); }; const deleteLiteLLMModelHandler = async () => { @@ -507,7 +507,7 @@ deleteLiteLLMModelName = ''; liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); - models.set(await getModels()); + models.set(await getAllModels()); }; const addModelInfoHandler = async () => { @@ -519,9 +519,7 @@ return; } // Remove any existing config - modelConfig = modelConfig.filter( - (m) => !(m.id === selectedModelId) - ); + modelConfig = modelConfig.filter((m) => !(m.id === selectedModelId)); // Add new config modelConfig.push({ id: selectedModelId, @@ -536,7 +534,7 @@ toast.success( $i18n.t('Model info for {{modelName}} added successfully', { modelName: selectedModelId }) ); - models.set(await getModels()); + models.set(await getAllModels()); }; const deleteModelInfoHandler = async () => { @@ -547,14 +545,12 @@ if (!model) { return; } - modelConfig = modelConfig.filter( - (m) => !(m.id === selectedModelId) - ); + modelConfig = modelConfig.filter((m) => !(m.id === selectedModelId)); await updateModelConfig(localStorage.token, modelConfig); toast.success( $i18n.t('Model info for {{modelName}} deleted successfully', { modelName: selectedModelId }) ); - models.set(await getModels()); + models.set(await getAllModels()); }; const toggleIsVisionCapable = () => { diff --git a/src/lib/components/chat/SettingsModal.svelte b/src/lib/components/chat/SettingsModal.svelte index 08207f604..e2782441c 100644 --- a/src/lib/components/chat/SettingsModal.svelte +++ b/src/lib/components/chat/SettingsModal.svelte @@ -3,7 +3,7 @@ import { toast } from 'svelte-sonner'; import { models, settings, user } from '$lib/stores'; - import { getModels as _getModels } from '$lib/utils'; + import { getAllModels as _getAllModels } from '$lib/utils'; import Modal from '../common/Modal.svelte'; import Account from './Settings/Account.svelte'; @@ -25,12 +25,12 @@ const saveSettings = async (updated) => { console.log(updated); await settings.set({ ...$settings, ...updated }); - await models.set(await getModels()); + await models.set(await getAllModels()); localStorage.setItem('settings', JSON.stringify($settings)); }; - const getModels = async () => { - return await _getModels(localStorage.token); + const getAllModels = async () => { + return await _getAllModels(localStorage.token); }; let selectedTab = 'general'; @@ -318,17 +318,17 @@