fix: now streaming works

smonux 2024-10-12 20:01:43 +02:00
parent c19c7b56cf
commit 712f82e4dd


@@ -11,7 +11,7 @@ import uuid
 import asyncio
 from contextlib import asynccontextmanager
-from typing import Optional
+from typing import Optional, Callable
 import aiohttp
 import requests
@@ -432,58 +432,106 @@ def update_body_request(request: Request,
     ]
     return None
 
+def extract_json(binary:bytes)->dict:
+    s = binary.decode("utf-8")
+    return json.loads(s[ s.find("{") : s.rfind("}") + 1 ])
+
+def fill_with_delta(fcall_dict:dict, delta:dict) -> None:
+    if not 'delta' in delta['choices'][0]:
+        return
+    j = delta['choices'][0]['delta']['tool_calls'][0]
+    if 'id' in j:
+        fcall_dict["id"] += j["id"]
+    if 'function' in j:
+        if 'name' in j['function']:
+            fcall_dict['function']['name'] += j['function']['name']
+        if 'arguments' in j['function']:
+            fcall_dict['function']['arguments'] += j['function']['arguments']
+
 async def handle_streaming_response(request: Request, response: Response,
                                     tools: dict,
                                     data_items: list,
-                                    call_next) -> StreamingResponse:
-    """content_type = response.headers["Content-Type"]
+                                    call_next: Callable,
+                                    user : UserModel) -> StreamingResponse:
+    content_type = response.headers["Content-Type"]
     is_openai = "text/event-stream" in content_type
     is_ollama = "application/x-ndjson" in content_type
-    if not is_openai and not is_ollama:
-        return response"""
-    log.debug("smonux 22")
 
     def wrap_item(item):
-        #return f"data: {item}\n\n" if is_openai else f"{item}\n"
-        return f"data: {item}\n\n"
+        return f"data: {item}\n\n" if is_openai else f"{item}\n"
 
     async def stream_wrapper(original_generator, data_items):
-        #for item in data_items:
-        #    yield wrap_item(json.dumps(item))
-        while True:
-            full_response = ""
-            async for data in original_generator:
-                full_response += data.decode('utf-8') if isinstance(data, bytes) else data
-                yield data
-            log.debug(f"smonux 24 {full_response}")
-            full_response_dict = json.loads(full_response)
-            if full_response_dict["choices"][0]["finish_reason"] != "tool_calls":
-                break
-            body["messages"].append(response_dict["choices"][0]["message"])
-            for tool_call in full_response_dict["choices"][0].get("tool_calls", []):
-                tool_function_name = tool_call["function"]["name"]
-                tool_function_params = json.loads(tool_call["function"]["arguments"])
-                try:
-                    log.debug(f"smonux 24 {tool_function_name}")
-                    tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
-                except Exception as e:
+        for item in data_items:
+            yield wrap_item(json.dumps(item))
+        body = json.loads(request._body)
+        generator = original_generator
+        try:
+            while True:
+                peek = await generator.__anext__()
+                if peek in (b'\n', b'data: [DONE]'):
+                    yield peek
+                    continue
+                peek_json = extract_json(peek)
+                if not 'tool_calls' in peek_json['choices'][0]['delta']:
+                    yield peek
+                    continue
+                # We reached a tool call; consume all the messages to assemble it
+                log.debug("async tool call detected")
+                tool_calls = []  # id, name, arguments
+                tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
+                current_index = peek_json['choices'][0]['delta']['tool_calls'][0]['index']
+                fill_with_delta(tool_calls[current_index], peek_json)
+                async for data in generator:
+                    log.debug(f"smonux 24 {data=}")
+                    if data in (b'\n', b'data: [DONE]\n'):
+                        continue
+                    delta = extract_json(data)
+                    if delta['choices'][0]['finish_reason'] is not None:
+                        continue
+                    i = delta['choices'][0]['delta']['tool_calls'][0]['index']
+                    if i != current_index:
+                        tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
+                        current_index = i
+                    fill_with_delta(tool_calls[i], delta)
+                log.debug(f"tools to call { [ t['function']['name'] for t in tool_calls] }")
+                body["messages"].append( {'role': 'assistant',
+                                          'content': None,
+                                          'refusal': None,
+                                          'tool_calls': tool_calls })
+                for tool_call in tool_calls:
+                    tool_function_name = tool_call["function"]["name"]
+                    tool_function_params = json.loads(tool_call["function"]["arguments"])
+                    try:
+                        tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
+                    except Exception as e:
                         tool_output = str(e)
                     # Append the tool output to the messages
                     body["messages"].append({
                         "role": "tool",
                         "tool_call_id" : tool_call["id"],
                         "name": tool_function_name,
                         "content": tool_output
                     })
-            # Make another request to the model with the updated context
-            update_body_request(request, body)
-            response = await call_next(request)
-            original_generator = response.body_iterator
+                log.debug("calling the model again with tool output included")
+                update_body_request(request, body)
+                response = await generate_chat_completions(form_data = body, user = user )
+                generator = response.body_iterator.__aiter__()
+        except Exception as e:
+            log.exception(f"Error: {e}")
 
     return StreamingResponse(
         stream_wrapper(response.body_iterator, data_items),
@@ -491,7 +539,7 @@ async def handle_streaming_response(request: Request, response: Response,
     )
 
 async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse:
-    # It only should be one response
+    # It should only be one response since we are in the async scenario
     async for data in response.body_iterator:
         content = data
         response_dict = json.loads(content)
@@ -786,8 +834,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if body.get("stream", False) is False:
             return await handle_nonstreaming_response(request, first_response, tools, user)
-        log.debug("smonux 20")
-        return await handle_streaming_response(request, first_response, tools, data_items, call_next)
+        return await handle_streaming_response(request, first_response, tools, data_items, call_next, user)
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}