feat: Add tool handling and response modification for non-streaming responses

This commit is contained in:
smonux 2024-10-06 11:09:32 +02:00 committed by smonux (aider)
parent 1f488a0072
commit bfc5f6b2f6

View File

@ -388,6 +388,63 @@ async def get_content_from_response(response) -> Optional[str]:
content = response["choices"][0]["message"]["content"]
return content
def get_tools_body(
    body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]:
    """Attach function-calling tool specs to a chat completion body.

    Reads ``tool_ids`` from ``body["metadata"]``. When none are present the
    body is returned untouched together with an empty mapping. Otherwise the
    tools are looked up via ``get_tools`` and every tool's ``"spec"`` is
    wrapped as ``{"type": "function", "function": spec}`` and stored under
    ``body["tools"]``.

    Args:
        body: Chat completion request payload; modified in place when tools
            are requested.
        user: Passed through to ``get_tools``.
        extra_params: Extra entries forwarded to ``get_tools`` alongside the
            ``__model__``, ``__messages__`` and ``__files__`` values.

    Returns:
        ``(body, tools)``, where ``tools`` is the value returned by
        ``get_tools`` or ``{}`` when no tool ids were supplied.
    """
    request_metadata = body.get("metadata", {})
    # The names `tool_ids` and `tools` are part of the log output because of
    # the f-string "=" specifier, so they must not be renamed.
    tool_ids = request_metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        return body, {}

    model_key = get_task_model_id(body["model"])
    tool_context = {
        **extra_params,
        "__model__": app.state.MODELS[model_key],
        "__messages__": body["messages"],
        "__files__": request_metadata.get("files", []),
    }
    tools = get_tools(webui_app, tool_ids, user, tool_context)
    log.info(f"{tools=}")

    body["tools"] = [
        {"type": "function", "function": tool["spec"]} for tool in tools.values()
    ]
    return body, tools
def update_body_request(request: Request, body: dict) -> None:
    """Swap the request's cached body for ``body`` serialized as JSON.

    The encoded payload is stored on ``request._body`` and the header list is
    rebuilt so that ``content-length`` matches the size of the new payload.

    Args:
        request: Request whose cached body and headers are overwritten.
        body: Payload to serialize with ``json.dumps`` and UTF-8 encode.
    """
    payload = json.dumps(body).encode("utf-8")
    request._body = payload

    # Drop any existing content-length entry, then put the corrected one first.
    kept_headers = [
        (name, value)
        for name, value in request.headers.raw
        if name.lower() != b"content-length"
    ]
    length_header = (b"content-length", str(len(payload)).encode("utf-8"))
    request.headers.__dict__["_list"] = [length_header, *kept_headers]
# Post-process a non-streaming chat completion response. In this revision the
# tool-call handling is only sketched: the response is returned as received.
#
# Parameters:
#   request  -- its cached `_body` is parsed as JSON into `body`.
#   response -- parsed with json.loads into `response_dict`, then returned.
#   tools    -- not referenced anywhere in the visible body.
#
# NOTE(review): indentation is not visible in this view, so the nesting of the
# `while`, `for` and `return` lines cannot be confirmed, and no statement that
# uses `tool_call` is visible. Verify the structure against the real file.
async def handle_nonstreaming_response(request : Request,
response : Response,
tools : dict) -> Response:
# NOTE(review): json.loads accepts str, bytes or bytearray, while `response`
# is annotated as Response -- confirm what callers actually pass here.
response_dict = json.loads(response)
body = json.loads(request._body)
# NOTE(review): `body` is not used after this point in the visible code.
# body["messages"]
# Loop while the first choice's finish_reason equals "tool_calls".
# NOTE(review): `response_dict` is never reassigned below, so this condition
# cannot change between iterations -- TODO confirm how the loop should end.
while response_dict["choices"][0]["finish_reason"] == "tool_calls":
# NOTE(review): get_content_from_response reads choices[0]["message"][...];
# here "tool_calls" is read directly from the choice -- confirm the key path.
for tool_call in response_dict["choices"][0]["tool_calls"]:
return response
async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict
@ -601,12 +658,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
}
body["metadata"] = metadata
try:
body, flags = await chat_completion_tools_handler(body, user, extra_params)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
log.exception(e)
body, tools = get_tools_body(body, user, extra_params)
try:
body, flags = await chat_completion_files_handler(body)
@ -651,20 +703,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0:
data_items.append({"citations": citations})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
response = await call_next(request)
if not isinstance(response, StreamingResponse):
update_body_request(request, body)
first_response = await call_next(request)
if not isinstance(first_response, StreamingResponse):
response = await handle_nonstreaming_response(request, response, tools)
return response
content_type = response.headers["Content-Type"]
content_type = first_response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
@ -681,8 +726,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, data_items),
headers=dict(response.headers),
stream_wrapper(first_response.body_iterator, data_items),
headers=dict(first_response.headers),
)
async def _receive(self, body: bytes):