diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 674ff50c4..f9028d667 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -39,6 +39,8 @@ from utils.utils import ( get_admin_user, ) +from utils.models import get_model_id_from_custom_model_id + from config import ( SRC_LOG_LEVELS, @@ -873,10 +875,10 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + model_id = get_model_id_from_custom_model_id(form_data.model) + model = model_id if url_idx == None: - model = form_data.model - if ":" not in model: model = f"{model}:latest" @@ -893,6 +895,13 @@ async def generate_chat_completion( r = None + # payload = { + # **form_data.model_dump_json(exclude_none=True).encode(), + # "model": model, + # "messages": form_data.messages, + + # } + log.debug( "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( form_data.model_dump_json(exclude_none=True).encode() diff --git a/backend/apps/web/models/models.py b/backend/apps/web/models/models.py index a50c9b5d2..2329b8f88 100644 --- a/backend/apps/web/models/models.py +++ b/backend/apps/web/models/models.py @@ -166,7 +166,9 @@ class ModelsTable: model = Model.get(Model.id == id) return ModelModel(**model_to_dict(model)) - except: + except Exception as e: + print(e) + return None def delete_model_by_id(self, id: str) -> bool: diff --git a/backend/apps/web/routers/models.py b/backend/apps/web/routers/models.py index 132f296ac..654d0d2fb 100644 --- a/backend/apps/web/routers/models.py +++ b/backend/apps/web/routers/models.py @@ -28,16 +28,24 @@ async def get_models(user=Depends(get_verified_user)): @router.post("/add", response_model=Optional[ModelModel]) -async def add_new_model(form_data: ModelForm, user=Depends(get_admin_user)): - model = Models.insert_new_model(form_data, user.id) - - if model: - return model - else: +async def add_new_model( + request: Request, form_data: ModelForm, user=Depends(get_admin_user) +): + if form_data.id in request.app.state.MODELS: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), + detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) + else: + model = Models.insert_new_model(form_data, user.id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) ############################ diff --git a/backend/constants.py b/backend/constants.py index be4d135b2..86875d2df 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." + NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." diff --git a/backend/utils/models.py b/backend/utils/models.py new file mode 100644 index 000000000..7a57b4fdb --- /dev/null +++ b/backend/utils/models.py @@ -0,0 +1,10 @@ +from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse + + +def get_model_id_from_custom_model_id(id: str): + model = Models.get_model_by_id(id) + + if model: + return model.id + else: + return id diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 13c274f76..6a4d998a3 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -194,7 +194,7 @@ await settings.set({ ..._settings, system: chatContent.system ?? _settings.system, - options: chatContent.options ?? _settings.options + params: chatContent.options ?? _settings.params }); autoScroll = true; await tick(); @@ -283,7 +283,7 @@ models: selectedModels, system: $settings.system ?? undefined, options: { - ...($settings.options ?? {}) + ...($settings.params ?? {}) }, messages: messages, history: history, @@ -431,7 +431,7 @@ // Prepare the base message object const baseMessage = { role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + content: message.content }; // Extract and format image URLs if any exist @@ -443,7 +443,6 @@ if (imageUrls && imageUrls.length > 0 && message.role === 'user') { baseMessage.images = imageUrls; } - return baseMessage; }); @@ -474,10 +473,10 @@ model: model, messages: messagesBody, options: { - ...($settings.options ?? {}), + ...($settings.params ?? {}), stop: - $settings?.options?.stop ?? undefined - ? $settings.options.stop.map((str) => + $settings?.params?.stop ?? undefined + ? $settings.params.stop.map((str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) ) : undefined @@ -718,18 +717,18 @@ : message?.raContent ?? message.content }) })), - seed: $settings?.options?.seed ?? undefined, + seed: $settings?.params?.seed ?? undefined, stop: - $settings?.options?.stop ?? undefined - ? $settings.options.stop.map((str) => + $settings?.params?.stop ?? undefined + ? $settings.params.stop.map((str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) ) : undefined, - temperature: $settings?.options?.temperature ?? undefined, - top_p: $settings?.options?.top_p ?? undefined, - num_ctx: $settings?.options?.num_ctx ?? undefined, - frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, - max_tokens: $settings?.options?.num_predict ?? undefined, + temperature: $settings?.params?.temperature ?? undefined, + top_p: $settings?.params?.top_p ?? undefined, + num_ctx: $settings?.params?.num_ctx ?? undefined, + frequency_penalty: $settings?.params?.repeat_penalty ?? undefined, + max_tokens: $settings?.params?.num_predict ?? undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0 }, @@ -1045,7 +1044,7 @@ bind:files bind:prompt bind:autoScroll - bind:selectedModel={atSelectedModel} + bind:atSelectedModel {selectedModels} {messages} {submitPrompt} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 7bb736012..c9bc3e3fe 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -27,7 +27,8 @@ export let stopResponse: Function; export let autoScroll = true; - export let selectedAtModel: Model | undefined; + + export let atSelectedModel: Model | undefined; export let selectedModels: ['']; let chatTextAreaElement: HTMLTextAreaElement; @@ -52,7 +53,6 @@ export let messages = []; let speechRecognition; - let visionCapableState = 'all'; $: if (prompt) { @@ -62,19 +62,48 @@ } } - $: { - if (selectedAtModel || selectedModels) { - visionCapableState = checkModelsAreVisionCapable(); - if (visionCapableState === 'none') { - // Remove all image files - const fileCount = files.length; - files = files.filter((file) => file.type != 'image'); - if (files.length < fileCount) { - toast.warning($i18n.t('All selected models do not support image input, removed images')); - } + // $: { + // if (atSelectedModel || selectedModels) { + // visionCapableState = checkModelsAreVisionCapable(); + // if (visionCapableState === 'none') { + // // Remove all image files + // const fileCount = files.length; + // files = files.filter((file) => file.type != 'image'); + // if (files.length < fileCount) { + // toast.warning($i18n.t('All selected models do not support image input, removed images')); + // } + // } + // } + // } + + const checkModelsAreVisionCapable = () => { + let modelsToCheck = []; + if (atSelectedModel !== undefined) { + modelsToCheck = [atSelectedModel.id]; + } else { + modelsToCheck = selectedModels; + } + if (modelsToCheck.length == 0 || modelsToCheck[0] == '') { + return 'all'; + } + let visionCapableCount = 0; + for (const modelName of modelsToCheck) { + const model = $models.find((m) => m.id === modelName); + if (!model) { + continue; + } + if (model.custom_info?.meta.vision_capable ?? true) { + visionCapableCount++; } } - } + if (visionCapableCount == modelsToCheck.length) { + return 'all'; + } else if (visionCapableCount == 0) { + return 'none'; + } else { + return 'some'; + } + }; let mediaRecorder; let audioChunks = []; @@ -343,35 +372,6 @@ } }; - const checkModelsAreVisionCapable = () => { - let modelsToCheck = []; - if (selectedAtModel !== undefined) { - modelsToCheck = [selectedAtModel.id]; - } else { - modelsToCheck = selectedModels; - } - if (modelsToCheck.length == 0 || modelsToCheck[0] == '') { - return 'all'; - } - let visionCapableCount = 0; - for (const modelName of modelsToCheck) { - const model = $models.find((m) => m.id === modelName); - if (!model) { - continue; - } - if (model.custom_info?.meta.vision_capable ?? true) { - visionCapableCount++; - } - } - if (visionCapableCount == modelsToCheck.length) { - return 'all'; - } else if (visionCapableCount == 0) { - return 'none'; - } else { - return 'some'; - } - }; - onMount(() => { window.setTimeout(() => chatTextAreaElement?.focus(), 0); @@ -479,8 +479,8 @@