mirror of
https://github.com/open-webui/open-webui
synced 2025-04-08 14:49:46 +00:00
enh: "stream" hook
This commit is contained in:
parent
205013da2f
commit
46c4da4864
@ -61,7 +61,12 @@ async def process_filter_functions(
|
||||
try:
|
||||
# Prepare parameters
|
||||
sig = inspect.signature(handler)
|
||||
params = {"body": form_data} | {
|
||||
|
||||
params = {"body": form_data}
|
||||
if filter_type == "stream":
|
||||
params = {"event": form_data}
|
||||
|
||||
params = params | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
|
@ -1048,6 +1048,21 @@ async def process_chat_response(
|
||||
):
|
||||
return response
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": metadata.get("model"),
|
||||
}
|
||||
filter_ids = get_sorted_filter_ids(form_data.get("model"))
|
||||
|
||||
# Streaming response
|
||||
if event_emitter and event_caller:
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
@ -1402,16 +1417,12 @@ async def process_chat_response(
|
||||
("reasoning", "/reasoning"),
|
||||
("thought", "/thought"),
|
||||
("Thought", "/Thought"),
|
||||
("|begin_of_thought|", "|end_of_thought|")
|
||||
("|begin_of_thought|", "|end_of_thought|"),
|
||||
]
|
||||
|
||||
code_interpreter_tags = [
|
||||
("code_interpreter", "/code_interpreter")
|
||||
]
|
||||
code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
|
||||
|
||||
solution_tags = [
|
||||
("|begin_of_solution|", "|end_of_solution|")
|
||||
]
|
||||
solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
@ -1455,6 +1466,14 @@ async def process_chat_response(
|
||||
try:
|
||||
data = json.loads(data)
|
||||
|
||||
data, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
|
||||
if "selected_model_id" in data:
|
||||
model_id = data["selected_model_id"]
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
@ -1968,16 +1987,29 @@ async def process_chat_response(
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
else:
|
||||
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n"
|
||||
|
||||
for event in events:
|
||||
event, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=event,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
yield wrap_item(json.dumps(event))
|
||||
|
||||
async for data in original_generator:
|
||||
data, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
yield data
|
||||
|
||||
return StreamingResponse(
|
||||
|
Loading…
Reference in New Issue
Block a user