diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 6434cfb16..f6643aea3 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -188,6 +188,13 @@ class ModelsTable: for model in db.query(Model).filter(Model.base_model_id != None).all() ] + def get_base_models(self) -> list[ModelModel]: + with get_db() as db: + return [ + ModelModel.model_validate(model) + for model in db.query(Model).filter(Model.base_model_id == None).all() + ] + def get_models_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ModelModel]: diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 7ba8d8190..ef8413b20 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -28,6 +28,16 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): return Models.get_models_by_user_id(user.id) +########################### +# GetBaseModels +########################### + + +@router.get("/base", response_model=list[ModelResponse]) +async def get_base_models(user=Depends(get_admin_user)): + return Models.get_base_models() + + ############################ # CreateNewModel ############################ diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index aa3e0437a..3c26c1672 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -993,7 +993,7 @@ async def get_all_models(): models.append( { - "id": f"open-webui-{custom_model.id}", + "id": f"open-webui.{custom_model.id}", "name": custom_model.name, "object": "model", "created": custom_model.created_at, @@ -1154,8 +1154,8 @@ async def generate_chat_completions( "selected_model_id": selected_model_id, } - if model_id.startswith("open-webui-"): - model_id = model_id[len("open-webui-") :] + if model_id.startswith("open-webui."): + model_id = model_id[len("open-webui.") :] form_data["model"] = model_id if model.get("pipe"): diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 40d0e0392..910f150d7 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -31,21 +31,6 @@ export const getModels = async (token: string = '') => { .filter((models) => models) // Sort the models .sort((a, b) => { - // Check if models have position property - const aHasPosition = a.info?.meta?.position !== undefined; - const bHasPosition = b.info?.meta?.position !== undefined; - - // If both a and b have the position property - if (aHasPosition && bHasPosition) { - return a.info.meta.position - b.info.meta.position; - } - - // If only a has the position property, it should come first - if (aHasPosition) return -1; - - // If only b has the position property, it should come first - if (bHasPosition) return 1; - // Compare case-insensitively by name for models without position property const lowerA = a.name.toLowerCase(); const lowerB = b.name.toLowerCase(); diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 86aeb2d89..90ec3286d 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,7 +1,7 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const getWorkspaceModels = async (token: string = '') => { +export const getModels = async (token: string = '') => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/models`, { @@ -34,6 +34,39 @@ export const getWorkspaceModels = async (token: string = '') => { +export const getBaseModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/base`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + + export const createNewModel = async (token: string, model: object) => { let error = null; diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index a19635379..6934157fa 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -1,8 +1,259 @@ -
-
Models
+
+
+
+ {$i18n.t('Models')} +
+ {filteredModels.length} +
+
+ +
+
+
+ +
+ +
+
+ +
+ {#each filteredModels as model (model.id)} +
+ +
+
+ modelfile profile +
+
+ +
+ +
{model.name}
+
+
+ {model?.meta?.description ?? model.id} +
+
+
+
+ {#if $user?.role === 'admin' || model.user_id === $user?.id} + + + + + + {/if} + +
+ + { + toggleModelById(localStorage.token, model.id); + _models.set(await getModels(localStorage.token)); + }} + /> + +
+
+
+ {/each} +
+ +{#if $user?.role === 'admin'} +
+
+ { + console.log(importFiles); + + let reader = new FileReader(); + reader.onload = async (event) => { + let savedModels = JSON.parse(event.target.result); + console.log(savedModels); + + for (const model of savedModels) { + if (model?.info ?? false) { + if ($_models.find((m) => m.id === model.id)) { + await updateModelById(localStorage.token, model.id, model.info).catch((error) => { + return null; + }); + } else { + await createNewModel(localStorage.token, model.info).catch((error) => { + return null; + }); + } + } + } + + await _models.set(await getModels(localStorage.token)); + init(); + }; + + reader.readAsText(importFiles[0]); + }} + /> + + + + +
+
+{/if} diff --git a/src/lib/components/workspace/Models.svelte b/src/lib/components/workspace/Models.svelte index ac50f73bd..6e9cf1f4c 100644 --- a/src/lib/components/workspace/Models.svelte +++ b/src/lib/components/workspace/Models.svelte @@ -15,7 +15,7 @@ import { createNewModel, deleteModelById, - getWorkspaceModels, + getModels as getWorkspaceModels, toggleModelById, updateModelById } from '$lib/apis/models'; @@ -29,7 +29,6 @@ import GarbageBin from '../icons/GarbageBin.svelte'; import Search from '../icons/Search.svelte'; import Plus from '../icons/Plus.svelte'; - import { get } from 'svelte/store'; import ChevronRight from '../icons/ChevronRight.svelte'; import Switch from '../common/Switch.svelte';