From d5f84d6234e616686ec7b3e573ba3141d0ab80da Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 15 Nov 2024 22:04:33 -0800 Subject: [PATCH] refac: model preset handling behaviour --- .../open_webui/apps/webui/models/models.py | 2 +- backend/open_webui/main.py | 10 +- src/lib/components/admin/Settings.svelte | 2 +- .../components/admin/Settings/Models.svelte | 444 ++++++++++-------- src/lib/components/workspace/Knowledge.svelte | 7 +- .../workspace/Models/ModelEditor.svelte | 34 +- .../workspace/common/AccessControl.svelte | 12 + 7 files changed, 305 insertions(+), 206 deletions(-) diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index f6643aea3..386f56a47 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -238,7 +238,7 @@ class ModelsTable: result = ( db.query(Model) .filter_by(id=id) - .update(model.model_dump(exclude={"id"}, exclude_none=True)) + .update(model.model_dump(exclude={"id"})) ) db.commit() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a77639561..a0311a6cb 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -948,7 +948,7 @@ async def get_all_models(): models = await get_all_base_models() # If there are no models, return an empty list - if len([model for model in models if model["owned_by"] != "arena"]) == 0: + if len([model for model in models if not model.get("arena", False)]) == 0: return [] global_action_ids = [ @@ -975,7 +975,7 @@ async def get_all_models(): action_ids.extend(model["info"]["meta"].get("actionIds", [])) model["action_ids"] = action_ids - else: + elif custom_model.id not in [model["id"] for model in models]: owned_by = "openai" pipe = None action_ids = [] @@ -997,7 +997,7 @@ async def get_all_models(): models.append( { - "id": f"open-webui.{custom_model.id}", + "id": f"{custom_model.id}", "name": custom_model.name, "object": "model", "created": custom_model.created_at, @@ -1164,10 +1164,6 @@ async def generate_chat_completions( "selected_model_id": selected_model_id, } - if model_id.startswith("open-webui."): - model_id = model_id[len("open-webui.") :] - form_data["model"] = model_id - if model.get("pipe"): # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter return await generate_function_chat_completion(form_data, user=user) diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 530f605ed..2cb72f2f7 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -327,7 +327,7 @@ -
+
{#if selectedTab === 'general'} { diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 704276044..ccc59838f 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -20,12 +20,20 @@ import Switch from '$lib/components/common/Switch.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; + import ModelEditor from '$lib/components/workspace/Models/ModelEditor.svelte'; + let importFiles; let modelsImportInputElement: HTMLInputElement; let models = null; + + let workspaceModels = null; + let baseModels = null; + let filteredModels = []; + let selectedModelId = null; + $: if (models) { filteredModels = models.filter( (m) => searchValue === '' || m.name.toLowerCase().includes(searchValue.toLowerCase()) @@ -42,23 +50,59 @@ }; const init = async () => { - const workspaceModels = await getBaseModels(localStorage.token); - const allModels = await getModels(localStorage.token, true); + workspaceModels = await getBaseModels(localStorage.token); + baseModels = await getModels(localStorage.token, true); - models = allModels - .filter((m) => !(m?.preset ?? false)) - .map((m) => { - const workspaceModel = workspaceModels.find((wm) => wm.id === m.id); + models = baseModels.map((m) => { + const workspaceModel = workspaceModels.find((wm) => wm.id === m.id); - if (workspaceModel) { - return workspaceModel; - } else { - return { - ...m, - is_active: true - }; - } + if (workspaceModel) { + return workspaceModel; + } else { + return { + id: m.id, + name: m.name, + is_active: true + }; + } + }); + }; + + const upsertModelHandler = async (model) => { + model.base_model_id = null; + + if (models.find((m) => m.id === model.id)) { + await updateModelById(localStorage.token, model.id, model).catch((error) => { + return null; }); + } else { + await createNewModel(localStorage.token, model).catch((error) => { + return null; + }); + } + + await init(); + }; + + const toggleModelHandler = async (model) => { + if (!Object.keys(model).includes('base_model_id')) { + await createNewModel(localStorage.token, { + id: model.id, + name: model.name, + base_model_id: null, + meta: {}, + params: {}, + is_active: model.is_active + }).catch((error) => { + return null; + }); + + await init(); + } else { + await toggleModelById(localStorage.token, model.id); + } + + _models.set(await getModels(localStorage.token)); }; onMount(async () => { @@ -67,201 +111,229 @@ {#if models !== null} -
-
-
- {$i18n.t('Models')} -
- {filteredModels.length} -
-
- -
-
-
- + {#if selectedModelId === null} +
+
+
+ {$i18n.t('Models')} +
+ {filteredModels.length} +
+
+ +
+
+
+ +
+
-
-
-
- {#each filteredModels as model (model.id)} -
- -
-
- modelfile profile -
-
- -
- -
{model.name}
-
-
- {model?.meta?.description ?? model.id} -
-
-
-
- {#if $user?.role === 'admin' || model.user_id === $user?.id} - + {#if models.length > 0} + {#each filteredModels as model (model.id)} +
+ +
+ + +
+ + { + toggleModelHandler(model); + }} + /> + +
+
+
+ {/each} + {:else} +
+
+ {$i18n.t('No models found')}
-
- {/each} -
+ {/if} +
- {#if $user?.role === 'admin'} -
-
- { - console.log(importFiles); + {#if $user?.role === 'admin'} +
+
+ { + console.log(importFiles); - let reader = new FileReader(); - reader.onload = async (event) => { - let savedModels = JSON.parse(event.target.result); - console.log(savedModels); + let reader = new FileReader(); + reader.onload = async (event) => { + let savedModels = JSON.parse(event.target.result); + console.log(savedModels); - for (const model of savedModels) { - if (model?.info ?? false) { - if ($_models.find((m) => m.id === model.id)) { - await updateModelById(localStorage.token, model.id, model.info).catch( - (error) => { - return null; - } - ); + for (const model of savedModels) { + if (Object.keys(model).includes('base_model_id')) { + if (model.base_model_id === null) { + upsertModelHandler(model); + } } else { - await createNewModel(localStorage.token, model.info).catch((error) => { - return null; - }); + if (model?.info ?? false) { + if (model.info.base_model_id === null) { + upsertModelHandler(model.info); + } + } } } - } - await _models.set(await getModels(localStorage.token)); - init(); - }; + await _models.set(await getModels(localStorage.token)); + init(); + }; - reader.readAsText(importFiles[0]); - }} - /> + reader.readAsText(importFiles[0]); + }} + /> - +
+ + + +
+ - +
+ + + +
+ +
-
+ {/if} + {:else} + m.id === selectedModelId)} + preset={false} + onSubmit={(model) => { + console.log(model); + upsertModelHandler(model); + selectedModelId = null; + }} + onBack={() => { + selectedModelId = null; + }} + /> {/if} {:else}
diff --git a/src/lib/components/workspace/Knowledge.svelte b/src/lib/components/workspace/Knowledge.svelte index 7945aee65..7b0d50658 100644 --- a/src/lib/components/workspace/Knowledge.svelte +++ b/src/lib/components/workspace/Knowledge.svelte @@ -10,15 +10,10 @@ const i18n = getContext('i18n'); import { WEBUI_NAME, knowledge } from '$lib/stores'; - import { getKnowledgeItems, deleteKnowledgeById } from '$lib/apis/knowledge'; - import { blobToFile, transformFileName } from '$lib/utils'; - import { goto } from '$app/navigation'; - import Tooltip from '../common/Tooltip.svelte'; - import GarbageBin from '../icons/GarbageBin.svelte'; - import Pencil from '../icons/Pencil.svelte'; + import DeleteConfirmDialog from '../common/ConfirmDialog.svelte'; import ItemMenu from './Knowledge/ItemMenu.svelte'; import Badge from '../common/Badge.svelte'; diff --git a/src/lib/components/workspace/Models/ModelEditor.svelte b/src/lib/components/workspace/Models/ModelEditor.svelte index baff78e40..a7039691a 100644 --- a/src/lib/components/workspace/Models/ModelEditor.svelte +++ b/src/lib/components/workspace/Models/ModelEditor.svelte @@ -1,8 +1,4 @@ {#if loaded} + {#if onBack} + + {/if} +
- {#if !edit || model} + {#if !edit || (edit && model)}
{ diff --git a/src/lib/components/workspace/common/AccessControl.svelte b/src/lib/components/workspace/common/AccessControl.svelte index 4f323dbe2..88ee49f3a 100644 --- a/src/lib/components/workspace/common/AccessControl.svelte +++ b/src/lib/components/workspace/common/AccessControl.svelte @@ -1,4 +1,6 @@