From 468c6398cd4162a5eb61a31efdfc2e3d68c6c213 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 24 May 2024 03:02:56 -0700 Subject: [PATCH] feat: unified models integration --- backend/apps/openai/main.py | 4 +- backend/main.py | 2 - src/lib/apis/index.ts | 2 +- src/lib/components/chat/Chat.svelte | 34 ++------- src/lib/components/chat/Messages.svelte | 10 +-- .../chat/Messages/Placeholder.svelte | 70 +++++++++---------- src/lib/components/chat/Messages/test.json | 28 ++++++++ src/lib/components/chat/ModelSelector.svelte | 12 ++-- .../chat/ModelSelector/Selector.svelte | 24 ++++--- 9 files changed, 94 insertions(+), 92 deletions(-) create mode 100644 src/lib/components/chat/Messages/test.json diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 467164d97..0b9735238 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -207,7 +207,7 @@ def merge_models_lists(model_lists): [ { **model, - "name": model["id"], + "name": model.get("name", model["id"]), "owned_by": "openai", "openai": model, "urlIdx": idx, @@ -319,6 +319,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = body.decode("utf-8") body = json.loads(body) + print(app.state.MODELS) + model = app.state.MODELS[body.get("model")] idx = model["urlIdx"] diff --git a/backend/main.py b/backend/main.py index fb65aa3e6..04ed5d2d7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -276,13 +276,11 @@ async def get_models(user=Depends(get_verified_user)): if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() - openai_app.state.MODELS = openai_models openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: ollama_models = await get_ollama_models() - ollama_app.state.MODELS = ollama_models print(ollama_models) diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a96dbfaf4..6a122f4dc 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -27,7 +27,7 @@ export const getModels = async (token: string = '') => { let models = res?.data ?? []; - models = models.filter((models) => models).reduce((a, e, i, arr) => a.concat(e), []); + models = models.filter((models) => models).sort((a, b) => (a.name > b.name ? 1 : -1)); console.log(models); return models; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 83369b735..13c274f76 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -11,7 +11,6 @@ chats, config, type Model, - modelfiles, models, settings, showSidebar, @@ -63,24 +62,6 @@ let selectedModels = ['']; let atSelectedModel: Model | undefined; - let selectedModelfile = null; - $: selectedModelfile = - selectedModels.length === 1 && - $modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0]).length > 0 - ? $modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0])[0] - : null; - - let selectedModelfiles = {}; - $: selectedModelfiles = selectedModels.reduce((a, tagName, i, arr) => { - const modelfile = - $modelfiles.filter((modelfile) => modelfile.tagName === tagName)?.at(0) ?? undefined; - - return { - ...a, - ...(modelfile && { [tagName]: modelfile }) - }; - }, {}); - let chat = null; let tags = []; @@ -345,6 +326,7 @@ const hasImages = messages.some((message) => message.files?.some((file) => file.type === 'image') ); + if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) { toast.error( $i18n.t('Model {{modelName}} is not vision capable', { @@ -362,7 +344,7 @@ role: 'assistant', content: '', model: model.id, - modelName: model.custom_info?.name ?? model.name ?? model.id, + modelName: model.name ?? model.id, userContext: null, timestamp: Math.floor(Date.now() / 1000) // Unix epoch }; @@ -407,7 +389,7 @@ } responseMessage.userContext = userContext; - if (model?.external) { + if (model?.owned_by === 'openai') { await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); } else if (model) { await sendPromptOllama(model, prompt, responseMessageId, _chatId); @@ -956,10 +938,8 @@ ) + ' {{prompt}}', titleModelId, userPrompt, - titleModel?.external ?? false - ? titleModel?.source?.toLowerCase() === 'litellm' - ? `${LITELLM_API_BASE_URL}/v1` - : `${OPENAI_API_BASE_URL}` + titleModel?.owned_by === 'openai' ?? false + ? `${OPENAI_API_BASE_URL}` : `${OLLAMA_API_BASE_URL}/v1` ); @@ -1046,16 +1026,12 @@ 0} - suggestionPrompts={chatIdProp - ? [] - : selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions} {sendPrompt} {continueGeneration} {regenerateResponse} diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index 1dbb3f19c..4e2a383a1 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -1,7 +1,7 @@