feat: __event_emitter__

This commit is contained in:
Timothy J. Baek 2024-07-01 20:05:02 -07:00
parent e5895af7a0
commit a07051f51b
2 changed files with 54 additions and 9 deletions

View File

@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app from apps.socket.main import sio, app as socket_app
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
OpenAIChatCompletionForm, OpenAIChatCompletionForm,
@ -277,7 +277,14 @@ def get_filter_function_ids(model):
async def get_function_call_response( async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user, model messages,
files,
tool_id,
template,
task_model_id,
user,
model,
__event_emitter__=None,
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
@ -414,6 +421,13 @@ async def get_function_call_response(
"__id__": tool_id, "__id__": tool_id,
} }
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": model,
}
if inspect.iscoroutinefunction(function): if inspect.iscoroutinefunction(function):
function_result = await function(**params) function_result = await function(**params)
else: else:
@ -437,7 +451,7 @@ async def get_function_call_response(
return None, None, False return None, None, False
async def chat_completion_functions_handler(body, model, user): async def chat_completion_functions_handler(body, model, user, __event_emitter__):
skip_files = None skip_files = None
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
**params, **params,
"__model__": model, "__model__": model,
} }
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
body = await inlet(**params) body = await inlet(**params)
@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
return body, {} return body, {}
async def chat_completion_tools_handler(body, model, user): async def chat_completion_tools_handler(body, model, user, __event_emitter__):
skip_files = None skip_files = None
contexts = [] contexts = []
@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
task_model_id=task_model_id, task_model_id=task_model_id,
user=user, user=user,
model=model, model=model,
__event_emitter__=__event_emitter__,
) )
print(file_handler) print(file_handler)
@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)}, content={"detail": str(e)},
) )
# Extract chat_id and message_id from the request body # Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
session_id = body["session_id"]
del body["session_id"]
chat_id = None chat_id = None
if "chat_id" in body: if "chat_id" in body:
chat_id = body["chat_id"] chat_id = body["chat_id"]
@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id = body["id"] message_id = body["id"]
del body["id"] del body["id"]
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": data,
},
to=session_id,
)
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client
data_items = [] data_items = []
@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
contexts = [] contexts = []
citations = [] citations = []
print(body)
try: try:
body, flags = await chat_completion_functions_handler(body, model, user) body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__
)
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
) )
try: try:
body, flags = await chat_completion_tools_handler(body, model, user) body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__
)
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", [])) citations.extend(flags.get("citations", []))

View File

@ -163,6 +163,10 @@
}; };
window.addEventListener('message', onMessageHandler); window.addEventListener('message', onMessageHandler);
$socket.on('chat-events', async (data) => {
console.log(data);
});
if (!$chatId) { if (!$chatId) {
chatId.subscribe(async (value) => { chatId.subscribe(async (value) => {
if (!value) { if (!value) {
@ -177,6 +181,8 @@
return () => { return () => {
window.removeEventListener('message', onMessageHandler); window.removeEventListener('message', onMessageHandler);
$socket.off('chat-events');
}; };
}); });
@ -683,6 +689,7 @@
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
session_id: $socket?.id,
chat_id: $chatId, chat_id: $chatId,
id: responseMessageId id: responseMessageId
}); });
@ -984,6 +991,7 @@
max_tokens: $settings?.params?.max_tokens ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
session_id: $socket?.id,
chat_id: $chatId, chat_id: $chatId,
id: responseMessageId id: responseMessageId
}, },