From a07051f51bf54ce638cbef3781dffcea34696f14 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 1 Jul 2024 20:05:02 -0700 Subject: [PATCH] feat: __event_emitter__ --- backend/main.py | 55 ++++++++++++++++++++++++----- src/lib/components/chat/Chat.svelte | 8 +++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/backend/main.py b/backend/main.py index 09d80f012..8b8cac2aa 100644 --- a/backend/main.py +++ b/backend/main.py @@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import StreamingResponse, Response, RedirectResponse -from apps.socket.main import app as socket_app +from apps.socket.main import sio, app as socket_app from apps.ollama.main import ( app as ollama_app, OpenAIChatCompletionForm, @@ -277,7 +277,14 @@ def get_filter_function_ids(model): async def get_function_call_response( - messages, files, tool_id, template, task_model_id, user, model + messages, + files, + tool_id, + template, + task_model_id, + user, + model, + __event_emitter__=None, ): tool = Tools.get_tool_by_id(tool_id) tools_specs = json.dumps(tool.specs, indent=2) @@ -414,6 +421,13 @@ async def get_function_call_response( "__id__": tool_id, } + if "__event_emitter__" in sig.parameters: + # Call the function with the '__event_emitter__' parameter included + params = { + **params, + "__event_emitter__": model, + } + if inspect.iscoroutinefunction(function): function_result = await function(**params) else: @@ -437,7 +451,7 @@ async def get_function_call_response( return None, None, False -async def chat_completion_functions_handler(body, model, user): +async def chat_completion_functions_handler(body, model, user, __event_emitter__): skip_files = None filter_ids = get_filter_function_ids(model) @@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user): **params, "__model__": model, } + if "__event_emitter__" in sig.parameters: + params = { + **params, + "__event_emitter__": __event_emitter__, + } if inspect.iscoroutinefunction(inlet): body = await inlet(**params) @@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user): return body, {} -async def chat_completion_tools_handler(body, model, user): +async def chat_completion_tools_handler(body, model, user, __event_emitter__): skip_files = None contexts = [] @@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user): task_model_id=task_model_id, user=user, model=model, + __event_emitter__=__event_emitter__, ) print(file_handler) @@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) - # Extract chat_id and message_id from the request body + # Extract session_id, chat_id and message_id from the request body + session_id = None + if "session_id" in body: + session_id = body["session_id"] + del body["session_id"] chat_id = None if "chat_id" in body: chat_id = body["chat_id"] @@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): message_id = body["id"] del body["id"] + async def __event_emitter__(data): + await sio.emit( + "chat-events", + { + "chat_id": chat_id, + "message_id": message_id, + "data": data, + }, + to=session_id, + ) + # Initialize data_items to store additional data to be sent to the client data_items = [] @@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): contexts = [] citations = [] - print(body) - try: - body, flags = await chat_completion_functions_handler(body, model, user) + body, flags = await chat_completion_functions_handler( + body, model, user, __event_emitter__ + ) except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, @@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) try: - body, flags = await chat_completion_tools_handler(body, model, user) + body, flags = await chat_completion_tools_handler( + body, model, user, __event_emitter__ + ) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 056432d42..769046367 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -163,6 +163,10 @@ }; window.addEventListener('message', onMessageHandler); + $socket.on('chat-events', async (data) => { + console.log(data); + }); + if (!$chatId) { chatId.subscribe(async (value) => { if (!value) { @@ -177,6 +181,8 @@ return () => { window.removeEventListener('message', onMessageHandler); + + $socket.off('chat-events'); }; }); @@ -683,6 +689,7 @@ keep_alive: $settings.keepAlive ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, + session_id: $socket?.id, chat_id: $chatId, id: responseMessageId }); @@ -984,6 +991,7 @@ max_tokens: $settings?.params?.max_tokens ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, + session_id: $socket?.id, chat_id: $chatId, id: responseMessageId },