enh: "stream" hook

This commit is contained in:
Timothy Jaeryang Baek 2025-02-25 01:00:29 -08:00
parent 205013da2f
commit 46c4da4864
2 changed files with 46 additions and 9 deletions

View File

@ -61,7 +61,12 @@ async def process_filter_functions(
try:
# Prepare parameters
sig = inspect.signature(handler)
params = {"body": form_data} | {
params = {"body": form_data}
if filter_type == "stream":
params = {"event": form_data}
params = params | {
k: v
for k, v in {
**extra_params,

View File

@ -1048,6 +1048,21 @@ async def process_chat_response(
):
return response
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
"__model__": metadata.get("model"),
}
filter_ids = get_sorted_filter_ids(form_data.get("model"))
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@ -1402,16 +1417,12 @@ async def process_chat_response(
("reasoning", "/reasoning"),
("thought", "/thought"),
("Thought", "/Thought"),
("|begin_of_thought|", "|end_of_thought|")
("|begin_of_thought|", "|end_of_thought|"),
]
code_interpreter_tags = [
("code_interpreter", "/code_interpreter")
]
code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
solution_tags = [
("|begin_of_solution|", "|end_of_solution|")
]
solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
try:
for event in events:
@ -1455,6 +1466,14 @@ async def process_chat_response(
try:
data = json.loads(data)
data, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=data,
extra_params=extra_params,
)
if "selected_model_id" in data:
model_id = data["selected_model_id"]
Chats.upsert_message_to_chat_by_id_and_message_id(
@ -1968,16 +1987,29 @@ async def process_chat_response(
return {"status": True, "task_id": task_id}
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n"
for event in events:
event, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=event,
extra_params=extra_params,
)
yield wrap_item(json.dumps(event))
async for data in original_generator:
data, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=data,
extra_params=extra_params,
)
yield data
return StreamingResponse(