From 14c0efe30088391e9b1ee2fa993e1ad096362f52 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 11 Jul 2024 18:47:38 -0700 Subject: [PATCH] feat: chat action integration --- src/lib/apis/index.ts | 39 +++++++++++++++++++ src/lib/components/chat/Chat.svelte | 36 ++++++++++++++++- src/lib/components/chat/Messages.svelte | 2 + .../chat/Messages/ResponseMessage.svelte | 2 + 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 9558e98f5..c2e90855b 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -104,6 +104,45 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => { return res; }; +type ChatActionForm = { + model: string; + messages: string[]; + chat_id: string; +}; + +export const chatAction = async (token: string, action_id: string, body: ChatActionForm) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/chat/actions/${action_id}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify(body) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getTaskConfig = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 8d82acb76..085eddbde 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -52,7 +52,7 @@ import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; - import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis'; + import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis'; import Banner from '../common/Banner.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -401,6 +401,39 @@ } }; + const chatActionHandler = async (actionId, modelId, responseMessageId) => { + const res = await chatAction(localStorage.token, actionId, { + model: modelId, + messages: messages.map((m) => ({ + id: m.id, + role: m.role, + content: m.content, + info: m.info ? m.info : undefined, + timestamp: m.timestamp + })), + chat_id: $chatId, + session_id: $socket?.id, + id: responseMessageId + }).catch((error) => { + toast.error(error); + messages.at(-1).error = { content: error }; + return null; + }); + + if (res !== null) { + // Update chat history with the new messages + for (const message of res.messages) { + history.messages[message.id] = { + ...history.messages[message.id], + ...(history.messages[message.id].content !== message.content + ? { originalContent: history.messages[message.id].content } + : {}), + ...message + }; + } + } + }; + const getChatEventEmitter = async (modelId: string, chatId: string = '') => { return setInterval(() => { $socket?.emit('usage', { @@ -1533,6 +1566,7 @@ {sendPrompt} {continueGeneration} {regenerateResponse} + {chatActionHandler} /> diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index 699a8fbc1..6b71e6596 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -22,6 +22,7 @@ export let sendPrompt: Function; export let continueGeneration: Function; export let regenerateResponse: Function; + export let chatActionHandler: Function; export let user = $_user; export let prompt; @@ -335,6 +336,7 @@ copyToClipboard={copyToClipboardWithToast} {continueGeneration} {regenerateResponse} + {chatActionHandler} on:save={async (e) => { console.log('save', e); diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index cb938dd52..329b2b1cb 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -55,6 +55,7 @@ export let copyToClipboard: Function; export let continueGeneration: Function; export let regenerateResponse: Function; + export let chatActionHandler: Function; let model = null; $: model = $models.find((m) => m.id === message.model); @@ -1030,6 +1031,7 @@ ? 'visible' : 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button" on:click={() => { + chatActionHandler(action.id, message.model, message.id); console.log('action'); }} >