From d6fd2a822818e714eb148db61b3edce72fcf1de2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 12 Jun 2024 21:18:53 -0700 Subject: [PATCH] refac --- backend/main.py | 68 ++++++ src/lib/apis/index.ts | 40 ++++ src/lib/components/chat/Chat.svelte | 215 +++++++++--------- src/lib/components/chat/MessageInput.svelte | 7 +- .../chat/MessageInput/CallOverlay.svelte | 145 +++++++++--- src/lib/components/chat/Messages.svelte | 2 +- .../components/chat/Settings/Interface.svelte | 32 +++ src/lib/utils/index.ts | 2 +- 8 files changed, 371 insertions(+), 140 deletions(-) diff --git a/backend/main.py b/backend/main.py index de8827d12..9de4d7111 100644 --- a/backend/main.py +++ b/backend/main.py @@ -494,6 +494,9 @@ def filter_pipeline(payload, user): if "title" in payload: del payload["title"] + if "task" in payload: + del payload["task"] + return payload @@ -835,6 +838,71 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, + "task": True, + } + + print(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/emoji/completions") +async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): + print("generate_emoji") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = ''' +You are a perceptive assistant skilled at interpreting emotions from a provided message. Your task is to reflect the speaker's likely facial expression through a fitting emoji. Prioritize using diverse facial expression emojis to convey the nuanced emotions expressed in the text. Please avoid using generic or overly ambiguous emojis like "🤔", and instead, choose ones that vividly represent the speaker's mood or reaction. + +Message: """{{prompt}}""" +''' + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 4, + "chat_id": form_data.get("chat_id", None), + "task": True, } print(payload) diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c40815611..1575885e9 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -205,6 +205,46 @@ export const generateTitle = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat'; }; +export const generateEmoji = async ( + token: string = '', + model: string, + prompt: string, + chat_id?: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + ...(chat_id && { chat_id: chat_id }) + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? null; +}; + export const generateSearchQuery = async ( token: string = '', model: string, diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 359056cfd..aa1462ff4 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -64,6 +64,8 @@ export let chatIdProp = ''; let loaded = false; + const eventTarget = new EventTarget(); + let stopResponseFlag = false; let autoScroll = true; let processing = ''; @@ -300,7 +302,7 @@ // Chat functions ////////////////////////// - const submitPrompt = async (userPrompt, _user = null) => { + const submitPrompt = async (userPrompt, { _raw = false } = {}) => { let _responses = []; console.log('submitPrompt', $chatId); @@ -344,7 +346,6 @@ parentId: messages.length !== 0 ? messages.at(-1).id : null, childrenIds: [], role: 'user', - user: _user ?? undefined, content: userPrompt, files: _files.length > 0 ? _files : undefined, timestamp: Math.floor(Date.now() / 1000), // Unix epoch @@ -362,15 +363,13 @@ // Wait until history/message have been updated await tick(); - - // Send prompt - _responses = await sendPrompt(userPrompt, userMessageId); + _responses = await sendPrompt(userPrompt, userMessageId, { newChat: true }); } return _responses; }; - const sendPrompt = async (prompt, parentId, modelId = null, newChat = true) => { + const sendPrompt = async (prompt, parentId, { modelId = null, newChat = false } = {}) => { let _responses = []; // If modelId is provided, use it, else use selected model @@ -490,7 +489,6 @@ responseMessage.userContext = userContext; const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); - if (webSearchEnabled) { await getWebSearchResults(model.id, parentId, responseMessageId); } @@ -503,8 +501,6 @@ } _responses.push(_response); - console.log('chatEventEmitter', chatEventEmitter); - if (chatEventEmitter) clearInterval(chatEventEmitter); } else { toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); @@ -513,88 +509,9 @@ ); await chats.set(await getChatList(localStorage.token)); - return _responses; }; - const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { - const responseMessage = history.messages[responseId]; - - responseMessage.statusHistory = [ - { - done: false, - action: 'web_search', - description: $i18n.t('Generating search query') - } - ]; - messages = messages; - - const prompt = history.messages[parentId].content; - let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( - (error) => { - console.log(error); - return prompt; - } - ); - - if (!searchQuery) { - toast.warning($i18n.t('No search query generated')); - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search query generated' - }); - - messages = messages; - } - - responseMessage.statusHistory.push({ - done: false, - action: 'web_search', - description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) - }); - messages = messages; - - const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { - console.log(error); - toast.error(error); - - return null; - }); - - if (results) { - responseMessage.statusHistory.push({ - done: true, - action: 'web_search', - description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), - query: searchQuery, - urls: results.filenames - }); - - if (responseMessage?.files ?? undefined === undefined) { - responseMessage.files = []; - } - - responseMessage.files.push({ - collection_name: results.collection_name, - name: searchQuery, - type: 'web_search_results', - urls: results.filenames - }); - - messages = messages; - } else { - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search results found' - }); - messages = messages; - } - }; - const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => { let _response = null; @@ -676,6 +593,8 @@ array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index ); + eventTarget.dispatchEvent(new CustomEvent('chat:start')); + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model.id, messages: messagesBody, @@ -745,6 +664,9 @@ continue; } else { responseMessage.content += data.message.content; + eventTarget.dispatchEvent( + new CustomEvent('chat', { detail: { content: data.message.content } }) + ); messages = messages; } } else { @@ -771,21 +693,13 @@ messages = messages; if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `${model.id}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png` - } - ); + const notification = new Notification(`${model.id}`, { + body: responseMessage.content, + icon: `${WEBUI_BASE_URL}/static/favicon.png` + }); } - if ($settings.responseAutoCopy) { + if ($settings?.responseAutoCopy ?? false) { copyToClipboard(responseMessage.content); } @@ -846,6 +760,7 @@ stopResponseFlag = false; await tick(); + eventTarget.dispatchEvent(new CustomEvent('chat:finish')); if (autoScroll) { scrollToBottom(); @@ -887,6 +802,8 @@ scrollToBottom(); + eventTarget.dispatchEvent(new CustomEvent('chat:start')); + try { const [res, controller] = await generateOpenAIChatCompletion( localStorage.token, @@ -1007,6 +924,7 @@ continue; } else { responseMessage.content += value; + eventTarget.dispatchEvent(new CustomEvent('chat', { detail: { content: value } })); messages = messages; } @@ -1057,6 +975,8 @@ stopResponseFlag = false; await tick(); + eventTarget.dispatchEvent(new CustomEvent('chat:finish')); + if (autoScroll) { scrollToBottom(); } @@ -1123,9 +1043,12 @@ let userPrompt = userMessage.content; if ((userMessage?.models ?? [...selectedModels]).length == 1) { - await sendPrompt(userPrompt, userMessage.id, undefined, false); + // If user message has only one model selected, sendPrompt automatically selects it for regeneration + await sendPrompt(userPrompt, userMessage.id); } else { - await sendPrompt(userPrompt, userMessage.id, message.model, false); + // If there are multiple models selected, use the model of the response message for regeneration + // e.g. many model chat + await sendPrompt(userPrompt, userMessage.id, { modelId: message.model }); } } }; @@ -1191,6 +1114,84 @@ } }; + const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { + const responseMessage = history.messages[responseId]; + + responseMessage.statusHistory = [ + { + done: false, + action: 'web_search', + description: $i18n.t('Generating search query') + } + ]; + messages = messages; + + const prompt = history.messages[parentId].content; + let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( + (error) => { + console.log(error); + return prompt; + } + ); + + if (!searchQuery) { + toast.warning($i18n.t('No search query generated')); + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search query generated' + }); + + messages = messages; + } + + responseMessage.statusHistory.push({ + done: false, + action: 'web_search', + description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) + }); + messages = messages; + + const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { + console.log(error); + toast.error(error); + + return null; + }); + + if (results) { + responseMessage.statusHistory.push({ + done: true, + action: 'web_search', + description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), + query: searchQuery, + urls: results.filenames + }); + + if (responseMessage?.files ?? undefined === undefined) { + responseMessage.files = []; + } + + responseMessage.files.push({ + collection_name: results.collection_name, + name: searchQuery, + type: 'web_search_results', + urls: results.filenames + }); + + messages = messages; + } else { + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search results found' + }); + messages = messages; + } + }; + const getTags = async () => { return await getTagsById(localStorage.token, $chatId).catch(async (error) => { return []; @@ -1206,7 +1207,13 @@ - + {#if !chatIdProp || (loaded && chatIdProp)}
{ @@ -467,7 +466,7 @@ document.getElementById('chat-textarea')?.focus(); if ($settings?.speechAutoSend ?? false) { - submitPrompt(prompt, user); + submitPrompt(prompt); } }} /> @@ -476,7 +475,7 @@ class="w-full flex gap-1.5" on:submit|preventDefault={() => { // check if selectedModels support image input - submitPrompt(prompt, user); + submitPrompt(prompt); }} >
{ - audioElement.muted = false; - }) - .catch((error) => { - toast.error(error); - }); + audioElement + .play() + .then(() => { + audioElement.muted = false; + }) + .catch((error) => { + toast.error(error); + }); - audioElement.onended = async (e) => { - await new Promise((r) => setTimeout(r, 300)); + audioElement.onended = async (e) => { + await new Promise((r) => setTimeout(r, 300)); - if (Object.keys(assistantAudio).length - 1 === idx) { - assistantSpeaking = false; - } + if (Object.keys(assistantAudio).length - 1 === idx) { + assistantSpeaking = false; + } - res(e); - }; + res(e); + }; + } }); } else { return Promise.resolve(); @@ -200,15 +213,8 @@ console.log(res.text); if (res.text !== '') { - const _responses = await submitPrompt(res.text); + const _responses = await submitPrompt(res.text, { _raw: true }); console.log(_responses); - - if (_responses.at(0)) { - const content = _responses[0]; - if ((content ?? '').trim() !== '') { - assistantSpeakingHandler(content); - } - } } } }; @@ -216,6 +222,23 @@ const assistantSpeakingHandler = async (content) => { assistantSpeaking = true; + if (modelId && ($settings?.showEmojiInCall ?? false)) { + console.log('Generating emoji'); + const res = await generateEmoji(localStorage.token, modelId, content, chatId).catch( + (error) => { + console.error(error); + return null; + } + ); + + if (res) { + console.log(res); + if (/\p{Extended_Pictographic}/u.test(res)) { + emoji = res.match(/\p{Extended_Pictographic}/gu)[0]; + } + } + } + if (($config.audio.tts.engine ?? '') == '') { let voices = []; const getVoicesLoop = setInterval(async () => { @@ -237,6 +260,10 @@ } speechSynthesis.speak(currentUtterance); + + currentUtterance.onend = async () => { + assistantSpeaking = false; + }; } }, 100); } else if ($config.audio.tts.engine === 'openai') { @@ -280,15 +307,22 @@ const audio = new Audio(blobUrl); assistantAudio[idx] = audio; lastPlayedAudioPromise = lastPlayedAudioPromise.then(() => playAudio(idx)); + + if (idx === sentences.length - 1) { + lastPlayedAudioPromise.then(() => { + assistantSpeaking = false; + }); + } } } } }; - const stopRecordingCallback = async () => { + const stopRecordingCallback = async (_continue = true) => { if ($showCallOverlay) { if (confirmed) { loading = true; + emoji = null; if (cameraStream) { const imageUrl = takeScreenshot(); @@ -310,7 +344,9 @@ audioChunks = []; mediaRecorder = false; - startRecording(); + if (_continue) { + startRecording(); + } } else { audioChunks = []; mediaRecorder = false; @@ -443,7 +479,30 @@ startRecording(); } else { stopCamera(); + stopAllAudio(); + stopRecordingCallback(false); } + + onMount(() => { + console.log(eventTarget); + + eventTarget.addEventListener('chat:start', async (e) => { + console.log('Chat start event:', e.detail); + message = ''; + }); + + eventTarget.addEventListener('chat', async (e) => { + const { content } = e.detail; + + message += content; + console.log('Chat event:', message); + }); + + eventTarget.addEventListener('chat:finish', async (e) => { + console.log('Chat finish event:', e.detail); + message = ''; + }); + }); {#if $showCallOverlay} @@ -492,6 +551,19 @@ r="3" /> + {:else if emoji} +
+ {emoji} +
{:else}
+ {:else if emoji} +
+ {emoji} +
{:else}
+
{$i18n.t('Display Emoji in Call')}
+ + +
+
+ {#if !$settings.chatBubble}
diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts index 15ac73f1b..830f315bc 100644 --- a/src/lib/utils/index.ts +++ b/src/lib/utils/index.ts @@ -436,7 +436,7 @@ export const removeEmojis = (str) => { export const extractSentences = (text) => { // Split the paragraph into sentences based on common punctuation marks - const sentences = text.split(/(?<=[.!?])/); + const sentences = text.split(/(?<=[.!?])\s+/); return sentences .map((sentence) => removeEmojis(sentence.trim()))