From 8c2ba7f7eabc48078d99a64c7df933ff57c8327c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 15 Aug 2024 17:28:43 +0200 Subject: [PATCH] enh: Actions `__webui__` flag support --- backend/main.py | 6 ++ src/lib/components/chat/Chat.svelte | 3 +- src/lib/components/chat/Messages.svelte | 8 +- .../chat/Messages/RateComment.svelte | 5 +- .../chat/Messages/ResponseMessage.svelte | 87 +++++++++++++++++-- 5 files changed, 99 insertions(+), 10 deletions(-) diff --git a/backend/main.py b/backend/main.py index c7bdb6284..838556c40 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1026,6 +1026,10 @@ async def get_all_models(): function_module, _, _ = load_function_module_by_id(action_id) webui_app.state.FUNCTIONS[action_id] = function_module + __webui__ = False + if hasattr(function_module, "__webui__"): + __webui__ = function_module.__webui__ + if hasattr(function_module, "actions"): actions = function_module.actions model["actions"].extend( @@ -1039,6 +1043,7 @@ async def get_all_models(): "icon_url": _action.get( "icon_url", action.meta.manifest.get("icon_url", None) ), + **({"__webui__": __webui__} if __webui__ else {}), } for _action in actions ] @@ -1050,6 +1055,7 @@ async def get_all_models(): "name": action.name, "description": action.meta.description, "icon_url": action.meta.manifest.get("icon_url", None), + **({"__webui__": __webui__} if __webui__ else {}), } ) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index b81315ced..9d28df0c1 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -430,7 +430,7 @@ } }; - const chatActionHandler = async (chatId, actionId, modelId, responseMessageId) => { + const chatActionHandler = async (chatId, actionId, modelId, responseMessageId, event = null) => { const res = await chatAction(localStorage.token, actionId, { model: modelId, messages: messages.map((m) => ({ @@ -440,6 +440,7 @@ info: m.info ? m.info : undefined, timestamp: m.timestamp })), + ...(event ? { event: event } : {}), chat_id: chatId, session_id: $socket?.id, id: responseMessageId diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index fb6754e86..9cc7f9526 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -342,7 +342,13 @@ {continueGeneration} {regenerateResponse} on:action={async (e) => { - await chatActionHandler(chatId, e.detail, message.model, message.id); + console.log('action', e); + if (typeof e.detail === 'string') { + await chatActionHandler(chatId, e.detail, message.model, message.id); + } else { + const { id, event } = e.detail; + await chatActionHandler(chatId, id, message.model, message.id, event); + } }} on:save={async (e) => { console.log('save', e); diff --git a/src/lib/components/chat/Messages/RateComment.svelte b/src/lib/components/chat/Messages/RateComment.svelte index 78eddecd9..73261ce66 100644 --- a/src/lib/components/chat/Messages/RateComment.svelte +++ b/src/lib/components/chat/Messages/RateComment.svelte @@ -57,7 +57,10 @@ message.annotation.reason = selectedReason; message.annotation.comment = comment; - dispatch('submit'); + dispatch('submit', { + reason: selectedReason, + comment: comment + }); toast.success($i18n.t('Thanks for your feedback!')); show = false; diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index c51510060..8e31a17c8 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -821,10 +821,24 @@ ?.annotation?.rating ?? null) === 1 ? 'bg-gray-100 dark:bg-gray-800' : ''} dark:hover:text-white hover:text-black transition" - on:click={() => { - rateMessage(message.id, 1); - showRateComment = true; + on:click={async () => { + await rateMessage(message.id, 1); + (model?.actions ?? []) + .filter((action) => action?.__webui__ ?? false) + .forEach((action) => { + dispatch('action', { + id: action.id, + event: { + id: 'good-response', + data: { + messageId: message.id + } + } + }); + }); + + showRateComment = true; window.setTimeout(() => { document .getElementById(`message-feedback-${message.id}`) @@ -856,8 +870,23 @@ ?.annotation?.rating ?? null) === -1 ? 'bg-gray-100 dark:bg-gray-800' : ''} dark:hover:text-white hover:text-black transition" - on:click={() => { - rateMessage(message.id, -1); + on:click={async () => { + await rateMessage(message.id, -1); + + (model?.actions ?? []) + .filter((action) => action?.__webui__ ?? false) + .forEach((action) => { + dispatch('action', { + id: action.id, + event: { + id: 'bad-response', + data: { + messageId: message.id + } + } + }); + }); + showRateComment = true; window.setTimeout(() => { document @@ -891,6 +920,20 @@ : 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button" on:click={() => { continueGeneration(); + + (model?.actions ?? []) + .filter((action) => action?.__webui__ ?? false) + .forEach((action) => { + dispatch('action', { + id: action.id, + event: { + id: 'continue-response', + data: { + messageId: message.id + } + } + }); + }); }} > { showRateComment = false; regenerateResponse(message); + + (model?.actions ?? []) + .filter((action) => action?.__webui__ ?? false) + .forEach((action) => { + dispatch('action', { + id: action.id, + event: { + id: 'regenerate-response', + data: { + messageId: message.id + } + } + }); + }); }} > - {#each model?.actions ?? [] as action} + {#each (model?.actions ?? []).filter((action) => !(action?.__webui__ ?? false)) as action}