From 0ef27bfc5e74743d678bbbdbe29314c50221fbf1 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 11 Jul 2024 10:40:10 -0700 Subject: [PATCH] refac --- backend/apps/socket/main.py | 31 ++++++++++++++++++++ backend/main.py | 57 ++++++++++++++----------------------- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 123ff31cd..d3fece56d 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -137,3 +137,34 @@ async def disconnect(sid): await sio.emit("user-count", {"count": len(USER_POOL)}) else: print(f"Unknown session ID {sid} disconnected") + + +async def get_event_emitter(request_info): + async def __event_emitter__(event_data): + await sio.emit( + "chat-events", + { + "chat_id": request_info["chat_id"], + "message_id": request_info["id"], + "data": event_data, + }, + to=request_info["session_id"], + ) + + return __event_emitter__ + + +async def get_event_call(request_info): + async def __event_call__(event_data): + response = await sio.call( + "chat-events", + { + "chat_id": request_info["chat_id"], + "message_id": request_info["id"], + "data": event_data, + }, + to=request_info["session_id"], + ) + return response + + return __event_call__ diff --git a/backend/main.py b/backend/main.py index 89252e164..869f88908 100644 --- a/backend/main.py +++ b/backend/main.py @@ -29,7 +29,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import StreamingResponse, Response, RedirectResponse -from apps.socket.main import sio, app as socket_app +from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call from apps.ollama.main import ( app as ollama_app, get_all_models as get_ollama_models, @@ -632,24 +632,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): message_id = body["id"] del body["id"] - async def __event_emitter__(data): - await sio.emit( - "chat-events", - { - "chat_id": chat_id, - "message_id": message_id, - "data": data, - }, - to=session_id, - ) - - async def __event_call__(data): - response = await sio.call( - "chat-events", - {"chat_id": chat_id, "message_id": message_id, "data": data}, - to=session_id, - ) - return response + __event_emitter__ = await get_event_emitter( + {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} + ) + __event_call__ = await get_event_call( + {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} + ) # Initialize data_items to store additional data to be sent to the client data_items = [] @@ -1107,24 +1095,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): else: pass - async def __event_emitter__(event_data): - await sio.emit( - "chat-events", - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "data": event_data, - }, - to=data["session_id"], - ) + __event_emitter__ = await get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) - async def __event_call__(event_data): - response = await sio.call( - "chat-events", - {"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data}, - to=data["session_id"], - ) - return response + __event_call__ = await get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) def get_priority(function_id): function = Functions.get_function_by_id(function_id)