diff --git a/backend/main.py b/backend/main.py index 8426536ef..79915e8f1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -302,6 +302,7 @@ async def get_function_call_response( user, model, __event_emitter__=None, + __event_call__=None, ): tool = Tools.get_tool_by_id(tool_id) tools_specs = json.dumps(tool.specs, indent=2) @@ -445,6 +446,13 @@ async def get_function_call_response( "__event_emitter__": __event_emitter__, } + if "__event_call__" in sig.parameters: + # Call the function with the '__event_call__' parameter included + params = { + **params, + "__event_call__": __event_call__, + } + if inspect.iscoroutinefunction(function): function_result = await function(**params) else: @@ -468,7 +476,9 @@ async def get_function_call_response( return None, None, False -async def chat_completion_functions_handler(body, model, user, __event_emitter__): +async def chat_completion_functions_handler( + body, model, user, __event_emitter__, __event_call__ +): skip_files = None filter_ids = get_filter_function_ids(model) @@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__ **params, "__model__": model, } + if "__event_emitter__" in sig.parameters: params = { **params, "__event_emitter__": __event_emitter__, } + if "__event_call__" in sig.parameters: + params = { + **params, + "__event_call__": __event_call__, + } + if inspect.iscoroutinefunction(inlet): body = await inlet(**params) else: @@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__ return body, {} -async def chat_completion_tools_handler(body, model, user, __event_emitter__): +async def chat_completion_tools_handler( + body, model, user, __event_emitter__, __event_call__ +): skip_files = None contexts = [] @@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__): user=user, model=model, __event_emitter__=__event_emitter__, + __event_call__=__event_call__, ) print(file_handler) @@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): to=session_id, ) + async def __event_call__(data): + response = await sio.call( + "chat-events", + {"chat_id": chat_id, "message_id": message_id, "data": data}, + to=session_id, + ) + return response + # Initialize data_items to store additional data to be sent to the client data_items = [] @@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_functions_handler( - body, model, user, __event_emitter__ + body, model, user, __event_emitter__, __event_call__ ) except Exception as e: return JSONResponse( @@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_tools_handler( - body, model, user, __event_emitter__ + body, model, user, __event_emitter__, __event_call__ ) contexts.extend(flags.get("contexts", [])) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 2d31b1462..5c39dc1c0 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -61,6 +61,7 @@ import CallOverlay from './MessageInput/CallOverlay.svelte'; import { error } from '@sveltejs/kit'; import ChatControls from './ChatControls.svelte'; + import EventConfirmDialog from '../common/ConfirmDialog.svelte'; const i18n: Writable = getContext('i18n'); @@ -74,6 +75,11 @@ let processing = ''; let messagesContainerElement: HTMLDivElement; + let showEventConfirmation = false; + let eventConfirmationTitle = ''; + let eventConfirmationMessage = ''; + let eventCallback = null; + let showModelSelector = true; let selectedModels = ['']; @@ -129,7 +135,7 @@ })(); } - const chatEventHandler = async (event) => { + const chatEventHandler = async (event, cb) => { if (event.chat_id === $chatId) { await tick(); console.log(event); @@ -139,17 +145,23 @@ const data = event?.data?.data ?? null; if (type === 'status') { - if (message.statusHistory) { + if (message?.statusHistory) { message.statusHistory.push(data); } else { message.statusHistory = [data]; } } else if (type === 'citation') { - if (message.citations) { + if (message?.citations) { message.citations.push(data); } else { message.citations = [data]; } + } else if (type === 'confirmation') { + eventCallback = cb; + showEventConfirmation = true; + + eventConfirmationTitle = data.title; + eventConfirmationMessage = data.message; } else { console.log('Unknown message type', data); } @@ -1392,6 +1404,18 @@