mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 13:07:25 +00:00
feat: __event_emitter__
This commit is contained in:
parent
e5895af7a0
commit
a07051f51b
@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
||||
|
||||
|
||||
from apps.socket.main import app as socket_app
|
||||
from apps.socket.main import sio, app as socket_app
|
||||
from apps.ollama.main import (
|
||||
app as ollama_app,
|
||||
OpenAIChatCompletionForm,
|
||||
@ -277,7 +277,14 @@ def get_filter_function_ids(model):
|
||||
|
||||
|
||||
async def get_function_call_response(
|
||||
messages, files, tool_id, template, task_model_id, user, model
|
||||
messages,
|
||||
files,
|
||||
tool_id,
|
||||
template,
|
||||
task_model_id,
|
||||
user,
|
||||
model,
|
||||
__event_emitter__=None,
|
||||
):
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
tools_specs = json.dumps(tool.specs, indent=2)
|
||||
@ -414,6 +421,13 @@ async def get_function_call_response(
|
||||
"__id__": tool_id,
|
||||
}
|
||||
|
||||
if "__event_emitter__" in sig.parameters:
|
||||
# Call the function with the '__event_emitter__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__event_emitter__": model,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(function):
|
||||
function_result = await function(**params)
|
||||
else:
|
||||
@ -437,7 +451,7 @@ async def get_function_call_response(
|
||||
return None, None, False
|
||||
|
||||
|
||||
async def chat_completion_functions_handler(body, model, user):
|
||||
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
|
||||
skip_files = None
|
||||
|
||||
filter_ids = get_filter_function_ids(model)
|
||||
@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
|
||||
**params,
|
||||
"__model__": model,
|
||||
}
|
||||
if "__event_emitter__" in sig.parameters:
|
||||
params = {
|
||||
**params,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
body = await inlet(**params)
|
||||
@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(body, model, user):
|
||||
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
||||
skip_files = None
|
||||
|
||||
contexts = []
|
||||
@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
|
||||
task_model_id=task_model_id,
|
||||
user=user,
|
||||
model=model,
|
||||
__event_emitter__=__event_emitter__,
|
||||
)
|
||||
|
||||
print(file_handler)
|
||||
@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
# Extract chat_id and message_id from the request body
|
||||
# Extract session_id, chat_id and message_id from the request body
|
||||
session_id = None
|
||||
if "session_id" in body:
|
||||
session_id = body["session_id"]
|
||||
del body["session_id"]
|
||||
chat_id = None
|
||||
if "chat_id" in body:
|
||||
chat_id = body["chat_id"]
|
||||
@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
message_id = body["id"]
|
||||
del body["id"]
|
||||
|
||||
async def __event_emitter__(data):
|
||||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"data": data,
|
||||
},
|
||||
to=session_id,
|
||||
)
|
||||
|
||||
# Initialize data_items to store additional data to be sent to the client
|
||||
data_items = []
|
||||
|
||||
@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
print(body)
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_functions_handler(body, model, user)
|
||||
body, flags = await chat_completion_functions_handler(
|
||||
body, model, user, __event_emitter__
|
||||
)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_tools_handler(body, model, user)
|
||||
body, flags = await chat_completion_tools_handler(
|
||||
body, model, user, __event_emitter__
|
||||
)
|
||||
|
||||
contexts.extend(flags.get("contexts", []))
|
||||
citations.extend(flags.get("citations", []))
|
||||
|
@ -163,6 +163,10 @@
|
||||
};
|
||||
window.addEventListener('message', onMessageHandler);
|
||||
|
||||
$socket.on('chat-events', async (data) => {
|
||||
console.log(data);
|
||||
});
|
||||
|
||||
if (!$chatId) {
|
||||
chatId.subscribe(async (value) => {
|
||||
if (!value) {
|
||||
@ -177,6 +181,8 @@
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('message', onMessageHandler);
|
||||
|
||||
$socket.off('chat-events');
|
||||
};
|
||||
});
|
||||
|
||||
@ -683,6 +689,7 @@
|
||||
keep_alive: $settings.keepAlive ?? undefined,
|
||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||
files: files.length > 0 ? files : undefined,
|
||||
session_id: $socket?.id,
|
||||
chat_id: $chatId,
|
||||
id: responseMessageId
|
||||
});
|
||||
@ -984,6 +991,7 @@
|
||||
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||
files: files.length > 0 ? files : undefined,
|
||||
session_id: $socket?.id,
|
||||
chat_id: $chatId,
|
||||
id: responseMessageId
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user