From 19c98b74fa6d38b7f1c356680d731f2f10160723 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 15 Nov 2024 19:14:24 -0800 Subject: [PATCH] refac: base models endpoint --- backend/open_webui/main.py | 14 ++++++++++++-- src/lib/apis/index.ts | 11 ++++++----- src/lib/components/admin/Settings/Models.svelte | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 3c26c1672..a77639561 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -915,8 +915,7 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION -async def get_all_models(): - # TODO: Optimize this function +async def get_all_base_models(): open_webui_models = [] openai_models = [] ollama_models = [] @@ -942,6 +941,11 @@ async def get_all_models(): open_webui_models = await get_open_webui_models() models = open_webui_models + openai_models + ollama_models + return models + + +async def get_all_models(): + models = await get_all_base_models() # If there are no models, return an empty list if len([model for model in models if model["owned_by"] != "arena"]) == 0: @@ -1084,6 +1088,12 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.get("/api/models/base") +async def get_base_models(user=Depends(get_admin_user)): + models = await get_all_base_models() + return {"data": models} + + @app.post("/api/chat/completions") async def generate_chat_completions( form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 910f150d7..7d7ca0e2d 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,9 +1,10 @@ import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; -export const getModels = async (token: string = '') => { +export const getModels = async (token: string = '', base: boolean = false) => { let error = null; - - const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/models${ + base ? '/base' : '' + }`, { method: 'GET', headers: { Accept: 'application/json', @@ -16,17 +17,17 @@ export const getModels = async (token: string = '') => { return res.json(); }) .catch((err) => { - console.log(err); error = err; + console.log(err); return null; }); + if (error) { throw error; } let models = res?.data ?? []; - models = models .filter((models) => models) // Sort the models diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 0740e5afa..704276044 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -43,7 +43,7 @@ const init = async () => { const workspaceModels = await getBaseModels(localStorage.token); - const allModels = await getModels(localStorage.token); + const allModels = await getModels(localStorage.token, true); models = allModels .filter((m) => !(m?.preset ?? false))