feat: adding citations to non async case, first steps

This commit is contained in:
smonux 2024-10-13 18:57:49 +02:00
parent 3683334b2a
commit 8c920c99c4

View File

@ -523,7 +523,6 @@ async def handle_streaming_response(request: Request, response: Response,
del body["metadata"]["files"] del body["metadata"]["files"]
if tools[tool_function_name]["citation"]: if tools[tool_function_name]["citation"]:
log.debug("smonux CITATION")
citations.append( { citations.append( {
"source": { "source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
@ -551,20 +550,24 @@ async def handle_streaming_response(request: Request, response: Response,
except Exception as e: except Exception as e:
log.exception(f"Error: {e}") log.exception(f"Error: {e}")
return StreamingResponse( return StreamingResponse(
stream_wrapper(response.body_iterator, data_items), stream_wrapper(response.body_iterator, data_items),
headers=dict(response.headers), headers=dict(response.headers),
) )
async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse: async def handle_nonstreaming_response(request: Request, response: Response,
# It only should be one response ince we are in the async scenario tools: dict, user: UserModel, data_items: list) -> JSONResponse:
# It only should be one response ince we are in the non streaming scenario
async for data in response.body_iterator: async for data in response.body_iterator:
content = data content = data
citations = []
response_dict = json.loads(content) response_dict = json.loads(content)
body = json.loads(request._body) body = json.loads(request._body)
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
while response_dict["choices"][0]["finish_reason"] == "tool_calls": while response_dict["choices"][0]["finish_reason"] == "tool_calls":
body["messages"].append(response_dict["choices"][0]["message"]) body["messages"].append(response_dict["choices"][0]["message"])
tool_calls = response_dict["choices"][0]["message"].get("tool_calls", []) tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
@ -577,6 +580,15 @@ async def handle_nonstreaming_response(request: Request, response: Response, too
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
if tools[tool_function_name]["citation"]:
citations.append( {
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
})
# Append the tool output to the messages # Append the tool output to the messages
body["messages"].append({ body["messages"].append({
"role": "tool", "role": "tool",
@ -589,6 +601,8 @@ async def handle_nonstreaming_response(request: Request, response: Response, too
update_body_request(request, body) update_body_request(request, body)
response_dict = await generate_chat_completions(form_data = body, user = user ) response_dict = await generate_chat_completions(form_data = body, user = user )
#FIXME: handle citations and data_items (Streaming Response)
return JSONResponse(content = response_dict) return JSONResponse(content = response_dict)
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
@ -852,7 +866,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
first_response = await call_next(request) first_response = await call_next(request)
if body.get("stream", False) is False: if body.get("stream", False) is False:
return await handle_nonstreaming_response(request, first_response, tools, user) return await handle_nonstreaming_response(request, first_response, tools, user, data_items)
return await handle_streaming_response(request, first_response, tools, data_items, call_next, user) return await handle_streaming_response(request, first_response, tools, data_items, call_next, user)