mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
feat: __event_emitter__
This commit is contained in:
parent
e5895af7a0
commit
a07051f51b
@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||||||
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
||||||
|
|
||||||
|
|
||||||
from apps.socket.main import app as socket_app
|
from apps.socket.main import sio, app as socket_app
|
||||||
from apps.ollama.main import (
|
from apps.ollama.main import (
|
||||||
app as ollama_app,
|
app as ollama_app,
|
||||||
OpenAIChatCompletionForm,
|
OpenAIChatCompletionForm,
|
||||||
@ -277,7 +277,14 @@ def get_filter_function_ids(model):
|
|||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_function_call_response(
|
||||||
messages, files, tool_id, template, task_model_id, user, model
|
messages,
|
||||||
|
files,
|
||||||
|
tool_id,
|
||||||
|
template,
|
||||||
|
task_model_id,
|
||||||
|
user,
|
||||||
|
model,
|
||||||
|
__event_emitter__=None,
|
||||||
):
|
):
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
tools_specs = json.dumps(tool.specs, indent=2)
|
tools_specs = json.dumps(tool.specs, indent=2)
|
||||||
@ -414,6 +421,13 @@ async def get_function_call_response(
|
|||||||
"__id__": tool_id,
|
"__id__": tool_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "__event_emitter__" in sig.parameters:
|
||||||
|
# Call the function with the '__event_emitter__' parameter included
|
||||||
|
params = {
|
||||||
|
**params,
|
||||||
|
"__event_emitter__": model,
|
||||||
|
}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(function):
|
if inspect.iscoroutinefunction(function):
|
||||||
function_result = await function(**params)
|
function_result = await function(**params)
|
||||||
else:
|
else:
|
||||||
@ -437,7 +451,7 @@ async def get_function_call_response(
|
|||||||
return None, None, False
|
return None, None, False
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_functions_handler(body, model, user):
|
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
filter_ids = get_filter_function_ids(model)
|
filter_ids = get_filter_function_ids(model)
|
||||||
@ -503,6 +517,11 @@ async def chat_completion_functions_handler(body, model, user):
|
|||||||
**params,
|
**params,
|
||||||
"__model__": model,
|
"__model__": model,
|
||||||
}
|
}
|
||||||
|
if "__event_emitter__" in sig.parameters:
|
||||||
|
params = {
|
||||||
|
**params,
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
if inspect.iscoroutinefunction(inlet):
|
||||||
body = await inlet(**params)
|
body = await inlet(**params)
|
||||||
@ -520,7 +539,7 @@ async def chat_completion_functions_handler(body, model, user):
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(body, model, user):
|
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
contexts = []
|
contexts = []
|
||||||
@ -542,6 +561,7 @@ async def chat_completion_tools_handler(body, model, user):
|
|||||||
task_model_id=task_model_id,
|
task_model_id=task_model_id,
|
||||||
user=user,
|
user=user,
|
||||||
model=model,
|
model=model,
|
||||||
|
__event_emitter__=__event_emitter__,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(file_handler)
|
print(file_handler)
|
||||||
@ -614,7 +634,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
content={"detail": str(e)},
|
content={"detail": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract chat_id and message_id from the request body
|
# Extract session_id, chat_id and message_id from the request body
|
||||||
|
session_id = None
|
||||||
|
if "session_id" in body:
|
||||||
|
session_id = body["session_id"]
|
||||||
|
del body["session_id"]
|
||||||
chat_id = None
|
chat_id = None
|
||||||
if "chat_id" in body:
|
if "chat_id" in body:
|
||||||
chat_id = body["chat_id"]
|
chat_id = body["chat_id"]
|
||||||
@ -624,6 +648,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
message_id = body["id"]
|
message_id = body["id"]
|
||||||
del body["id"]
|
del body["id"]
|
||||||
|
|
||||||
|
async def __event_emitter__(data):
|
||||||
|
await sio.emit(
|
||||||
|
"chat-events",
|
||||||
|
{
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"data": data,
|
||||||
|
},
|
||||||
|
to=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize data_items to store additional data to be sent to the client
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
data_items = []
|
data_items = []
|
||||||
|
|
||||||
@ -631,10 +666,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
contexts = []
|
contexts = []
|
||||||
citations = []
|
citations = []
|
||||||
|
|
||||||
print(body)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_functions_handler(body, model, user)
|
body, flags = await chat_completion_functions_handler(
|
||||||
|
body, model, user, __event_emitter__
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -642,7 +677,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_tools_handler(body, model, user)
|
body, flags = await chat_completion_tools_handler(
|
||||||
|
body, model, user, __event_emitter__
|
||||||
|
)
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
contexts.extend(flags.get("contexts", []))
|
||||||
citations.extend(flags.get("citations", []))
|
citations.extend(flags.get("citations", []))
|
||||||
|
@ -163,6 +163,10 @@
|
|||||||
};
|
};
|
||||||
window.addEventListener('message', onMessageHandler);
|
window.addEventListener('message', onMessageHandler);
|
||||||
|
|
||||||
|
$socket.on('chat-events', async (data) => {
|
||||||
|
console.log(data);
|
||||||
|
});
|
||||||
|
|
||||||
if (!$chatId) {
|
if (!$chatId) {
|
||||||
chatId.subscribe(async (value) => {
|
chatId.subscribe(async (value) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -177,6 +181,8 @@
|
|||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
window.removeEventListener('message', onMessageHandler);
|
window.removeEventListener('message', onMessageHandler);
|
||||||
|
|
||||||
|
$socket.off('chat-events');
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -683,6 +689,7 @@
|
|||||||
keep_alive: $settings.keepAlive ?? undefined,
|
keep_alive: $settings.keepAlive ?? undefined,
|
||||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||||
files: files.length > 0 ? files : undefined,
|
files: files.length > 0 ? files : undefined,
|
||||||
|
session_id: $socket?.id,
|
||||||
chat_id: $chatId,
|
chat_id: $chatId,
|
||||||
id: responseMessageId
|
id: responseMessageId
|
||||||
});
|
});
|
||||||
@ -984,6 +991,7 @@
|
|||||||
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
||||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||||
files: files.length > 0 ? files : undefined,
|
files: files.length > 0 ? files : undefined,
|
||||||
|
session_id: $socket?.id,
|
||||||
chat_id: $chatId,
|
chat_id: $chatId,
|
||||||
id: responseMessageId
|
id: responseMessageId
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user