refac
This commit is contained in:
@@ -2416,8 +2416,79 @@ def get_event_emitter_and_caller(metadata):
|
||||
return event_emitter, event_caller
|
||||
|
||||
|
||||
async def background_tasks_handler(request, form_data, user, metadata, tasks):
|
||||
def build_chat_response_context(
|
||||
request, form_data, user, model, metadata, tasks, events
|
||||
):
|
||||
event_emitter, event_caller = get_event_emitter_and_caller(metadata)
|
||||
return {
|
||||
"request": request,
|
||||
"form_data": form_data,
|
||||
"user": user,
|
||||
"model": model,
|
||||
"metadata": metadata,
|
||||
"tasks": tasks,
|
||||
"events": events,
|
||||
"event_emitter": event_emitter,
|
||||
"event_caller": event_caller,
|
||||
}
|
||||
|
||||
|
||||
def get_response_data(response):
|
||||
if isinstance(response, list) and len(response) == 1:
|
||||
# If the response is a single-item list, unwrap it #17213
|
||||
response = response[0]
|
||||
|
||||
if isinstance(response, JSONResponse):
|
||||
if isinstance(response.body, bytes):
|
||||
try:
|
||||
response_data = json.loads(response.body.decode("utf-8", "replace"))
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"error": {"detail": "Invalid JSON response"}}
|
||||
else:
|
||||
response_data = response
|
||||
elif isinstance(response, dict):
|
||||
response_data = response
|
||||
else:
|
||||
response_data = None
|
||||
|
||||
return response, response_data
|
||||
|
||||
|
||||
def merge_events_into_response(response_data, events):
|
||||
if events and isinstance(events, list):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
return {
|
||||
**extra_response,
|
||||
**response_data,
|
||||
}
|
||||
return response_data
|
||||
|
||||
|
||||
def build_response_object(response, response_data):
|
||||
if isinstance(response, dict):
|
||||
return response_data
|
||||
if isinstance(response, JSONResponse):
|
||||
return JSONResponse(
|
||||
content=response_data,
|
||||
headers=response.headers,
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def background_tasks_handler(ctx):
|
||||
request = ctx["request"]
|
||||
form_data = ctx["form_data"]
|
||||
user = ctx["user"]
|
||||
metadata = ctx["metadata"]
|
||||
tasks = ctx["tasks"]
|
||||
event_emitter = ctx["event_emitter"]
|
||||
|
||||
message = None
|
||||
messages = []
|
||||
@@ -2633,184 +2704,144 @@ async def background_tasks_handler(request, form_data, user, metadata, tasks):
|
||||
pass
|
||||
|
||||
|
||||
async def process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
):
|
||||
event_emitter, event_caller = get_event_emitter_and_caller(metadata)
|
||||
async def non_streaming_chat_response_handler(response, ctx):
|
||||
request = ctx["request"]
|
||||
|
||||
# Non-streaming response
|
||||
if not isinstance(response, StreamingResponse):
|
||||
if event_emitter:
|
||||
try:
|
||||
if isinstance(response, dict) or isinstance(response, JSONResponse):
|
||||
if isinstance(response, list) and len(response) == 1:
|
||||
# If the response is a single-item list, unwrap it #17213
|
||||
response = response[0]
|
||||
user = ctx["user"]
|
||||
metadata = ctx["metadata"]
|
||||
events = ctx["events"]
|
||||
|
||||
if isinstance(response, JSONResponse) and isinstance(
|
||||
response.body, bytes
|
||||
):
|
||||
try:
|
||||
response_data = json.loads(
|
||||
response.body.decode("utf-8", "replace")
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
response_data = {
|
||||
"error": {"detail": "Invalid JSON response"}
|
||||
}
|
||||
else:
|
||||
response_data = response
|
||||
event_emitter = ctx["event_emitter"]
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data.get("error")
|
||||
response, response_data = get_response_data(response)
|
||||
if response_data is None:
|
||||
return response
|
||||
|
||||
if isinstance(error, dict):
|
||||
error = error.get("detail", error)
|
||||
else:
|
||||
error = str(error)
|
||||
if event_emitter:
|
||||
try:
|
||||
if "error" in response_data:
|
||||
error = response_data.get("error")
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": error},
|
||||
if isinstance(error, dict):
|
||||
error = error.get("detail", error)
|
||||
else:
|
||||
error = str(error)
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"error": {"content": error},
|
||||
},
|
||||
)
|
||||
if isinstance(error, str) or isinstance(error, dict):
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message:error",
|
||||
"data": {"error": {"content": error}},
|
||||
}
|
||||
)
|
||||
|
||||
if "selected_model_id" in response_data:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": response_data["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
choices = response_data.get("choices", [])
|
||||
if choices and choices[0].get("message", {}).get("content"):
|
||||
content = response_data["choices"][0]["message"]["content"]
|
||||
|
||||
if content:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": response_data,
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
|
||||
# Use output from backend if provided (OR-compliant backends)
|
||||
response_output = response_data.get("output")
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"done": True,
|
||||
"content": content,
|
||||
**(
|
||||
{"output": response_output}
|
||||
if response_output
|
||||
else {}
|
||||
),
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
if isinstance(error, str) or isinstance(error, dict):
|
||||
await event_emitter(
|
||||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
**({"output": response_output} if response_output else {}),
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
await post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"type": "chat:message:error",
|
||||
"data": {"error": {"content": error}},
|
||||
}
|
||||
)
|
||||
|
||||
if "selected_model_id" in response_data:
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": response_data["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
choices = response_data.get("choices", [])
|
||||
if choices and choices[0].get("message", {}).get("content"):
|
||||
content = response_data["choices"][0]["message"]["content"]
|
||||
|
||||
if content:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": response_data,
|
||||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
|
||||
# Use output from backend if provided (OR-compliant backends)
|
||||
response_output = response_data.get("output")
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"done": True,
|
||||
"content": content,
|
||||
**(
|
||||
{"output": response_output}
|
||||
if response_output
|
||||
else {}
|
||||
),
|
||||
"title": title,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
**(
|
||||
{"output": response_output}
|
||||
if response_output
|
||||
else {}
|
||||
),
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not Users.is_user_active(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
await post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
webhook_url,
|
||||
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
||||
{
|
||||
"action": "chat",
|
||||
"message": content,
|
||||
"title": title,
|
||||
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
||||
},
|
||||
)
|
||||
await background_tasks_handler(ctx)
|
||||
|
||||
await background_tasks_handler(
|
||||
request, form_data, user, metadata, tasks
|
||||
)
|
||||
response = build_response_object(
|
||||
response, merge_events_into_response(response_data, events)
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(f"Error occurred while processing request: {e}")
|
||||
pass
|
||||
|
||||
if events and isinstance(events, list):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
response_data = {
|
||||
**extra_response,
|
||||
**response_data,
|
||||
}
|
||||
|
||||
if isinstance(response, dict):
|
||||
response = response_data
|
||||
if isinstance(response, JSONResponse):
|
||||
response = JSONResponse(
|
||||
content=response_data,
|
||||
headers=response.headers,
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.debug(f"Error occurred while processing request: {e}")
|
||||
pass
|
||||
|
||||
return response
|
||||
else:
|
||||
if events and isinstance(events, list) and isinstance(response, dict):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
response = {
|
||||
**extra_response,
|
||||
**response,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
# Non standard response
|
||||
if not any(
|
||||
content_type in response.headers["Content-Type"]
|
||||
for content_type in ["text/event-stream", "application/x-ndjson"]
|
||||
):
|
||||
return response
|
||||
|
||||
if isinstance(response, dict):
|
||||
response = merge_events_into_response(response_data, events)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def streaming_chat_response_handler(response, ctx):
|
||||
request = ctx["request"]
|
||||
|
||||
form_data = ctx["form_data"]
|
||||
|
||||
user = ctx["user"]
|
||||
model = ctx["model"]
|
||||
|
||||
metadata = ctx["metadata"]
|
||||
events = ctx["events"]
|
||||
|
||||
event_emitter = ctx["event_emitter"]
|
||||
event_caller = ctx["event_caller"]
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
@@ -2830,6 +2861,7 @@ async def process_chat_response(
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
@@ -2837,7 +2869,7 @@ async def process_chat_response(
|
||||
)
|
||||
]
|
||||
|
||||
# Streaming response
|
||||
# Standard streaming response handler
|
||||
if event_emitter and event_caller:
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
model_id = form_data.get("model", "")
|
||||
@@ -4148,9 +4180,9 @@ async def process_chat_response(
|
||||
blocking_code = textwrap.dedent(
|
||||
f"""
|
||||
import builtins
|
||||
|
||||
|
||||
BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES}
|
||||
|
||||
|
||||
_real_import = builtins.__import__
|
||||
def restricted_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name.split('.')[0] in BLOCKED_MODULES:
|
||||
@@ -4160,7 +4192,7 @@ async def process_chat_response(
|
||||
f"Direct import of module {{name}} is restricted."
|
||||
)
|
||||
return _real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
|
||||
builtins.__import__ = restricted_import
|
||||
"""
|
||||
)
|
||||
@@ -4351,9 +4383,7 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
await background_tasks_handler(
|
||||
request, form_data, user, metadata, tasks
|
||||
)
|
||||
await background_tasks_handler(ctx)
|
||||
except asyncio.CancelledError:
|
||||
log.warning("Task was cancelled!")
|
||||
await event_emitter({"type": "chat:tasks:cancel"})
|
||||
@@ -4410,3 +4440,19 @@ async def process_chat_response(
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
|
||||
|
||||
async def process_chat_response(response, ctx):
|
||||
# Non-streaming response
|
||||
if not isinstance(response, StreamingResponse):
|
||||
return await non_streaming_chat_response_handler(response, ctx)
|
||||
|
||||
# Non standard response
|
||||
if not any(
|
||||
content_type in response.headers["Content-Type"]
|
||||
for content_type in ["text/event-stream", "application/x-ndjson"]
|
||||
):
|
||||
return response
|
||||
|
||||
# Streaming response
|
||||
return await streaming_chat_response_handler(response, ctx)
|
||||
|
||||
Reference in New Issue
Block a user