mirror of
https://github.com/open-webui/open-webui
synced 2024-11-26 22:09:51 +00:00
fix: now streaming works
This commit is contained in:
parent
c19c7b56cf
commit
712f82e4dd
@ -11,7 +11,7 @@ import uuid
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
from typing import Optional, Callable
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -432,58 +432,106 @@ def update_body_request(request: Request,
|
|||||||
]
|
]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def extract_json(binary: bytes) -> dict:
    """Parse the JSON object embedded in one streamed response chunk.

    Everything before the first ``{`` and after the last ``}`` is discarded,
    so framing around the payload (for example a ``data: `` prefix and
    trailing newlines) does not need to be stripped by the caller.

    Args:
        binary: The raw chunk. UTF-8 encoded bytes are decoded first; an
            already-decoded ``str`` is accepted as-is.

    Returns:
        The value produced by ``json.loads`` for the text between the
        outermost braces.

    Raises:
        json.JSONDecodeError: If the chunk contains no ``{...}`` span or the
            span is not valid JSON.
        UnicodeDecodeError: If a bytes chunk is not valid UTF-8.
    """
    if isinstance(binary, (bytes, bytearray)):
        text = binary.decode("utf-8")
    else:
        text = binary

    start = text.find("{")
    end = text.rfind("}")
    # Without this guard a missing brace yields -1 and silently slices the
    # wrong part of the chunk; fail with the same exception type instead.
    if start == -1 or end < start:
        raise json.JSONDecodeError("no JSON object found in chunk", text, 0)

    return json.loads(text[start:end + 1])
|
||||||
|
|
||||||
|
def fill_with_delta(fcall_dict: dict, delta: dict) -> None:
    """Append one streamed tool-call fragment to an accumulated tool call.

    Only the first entry of the chunk's ``tool_calls`` list is read. Its
    ``id``, ``function.name`` and ``function.arguments`` strings, when
    present, are concatenated onto the matching fields of ``fcall_dict``.

    Args:
        fcall_dict: Accumulator mutated in place. It must already hold string
            values under ``"id"``, ``["function"]["name"]`` and
            ``["function"]["arguments"]``.
        delta: One decoded streaming chunk, read as
            ``delta["choices"][0]["delta"]["tool_calls"][0]``.

    Returns:
        None. Chunks that carry no tool-call fragment leave ``fcall_dict``
        unchanged.
    """
    choices = delta.get("choices") or [{}]
    # Content-only chunks and the final chunk have no tool_calls entry.
    tool_calls = (choices[0].get("delta") or {}).get("tool_calls")
    if not tool_calls:
        return

    fragment = tool_calls[0]
    function = fragment.get("function") or {}
    target = fcall_dict["function"]

    # `or ""` treats an explicit null the same as an absent key, so a null
    # field in a follow-up chunk cannot break the string concatenation.
    fcall_dict["id"] += fragment.get("id") or ""
    target["name"] += function.get("name") or ""
    target["arguments"] += function.get("arguments") or ""
|
||||||
|
|
||||||
async def handle_streaming_response(request: Request, response: Response,
|
async def handle_streaming_response(request: Request, response: Response,
|
||||||
tools: dict,
|
tools: dict,
|
||||||
data_items: list,
|
data_items: list,
|
||||||
call_next) -> StreamingResponse:
|
call_next: Callable,
|
||||||
|
user : UserModel) -> StreamingResponse:
|
||||||
|
|
||||||
"""content_type = response.headers["Content-Type"]
|
content_type = response.headers["Content-Type"]
|
||||||
is_openai = "text/event-stream" in content_type
|
is_openai = "text/event-stream" in content_type
|
||||||
is_ollama = "application/x-ndjson" in content_type
|
is_ollama = "application/x-ndjson" in content_type
|
||||||
if not is_openai and not is_ollama:
|
|
||||||
return response"""
|
|
||||||
|
|
||||||
log.debug("smonux 22")
|
|
||||||
def wrap_item(item):
|
def wrap_item(item):
|
||||||
#return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
||||||
return f"data: {item}\n\n"
|
|
||||||
|
|
||||||
async def stream_wrapper(original_generator, data_items):
|
async def stream_wrapper(original_generator, data_items):
|
||||||
#for item in data_items:
|
for item in data_items:
|
||||||
# yield wrap_item(json.dumps(item))
|
yield wrap_item(json.dumps(item))
|
||||||
while True:
|
|
||||||
full_response = ""
|
body = json.loads(request._body)
|
||||||
async for data in original_generator:
|
generator = original_generator
|
||||||
full_response += data.decode('utf-8') if isinstance(data, bytes) else data
|
try:
|
||||||
yield data
|
while True:
|
||||||
log.debug(f"smonux 24 {full_response}")
|
peek = await generator.__anext__()
|
||||||
|
if peek in (b'\n', b'data: [DONE]'):
|
||||||
|
yield peek
|
||||||
|
continue
|
||||||
|
peek_json = extract_json(peek)
|
||||||
|
if not 'tool_calls' in peek_json['choices'][0]['delta']:
|
||||||
|
yield peek
|
||||||
|
continue
|
||||||
|
|
||||||
full_response_dict = json.loads(full_response)
|
# We reached a tool call; consume all the following messages to assemble it
|
||||||
if full_response_dict["choices"][0]["finish_reason"] != "tool_calls":
|
log.debug("async tool call detected")
|
||||||
break
|
tool_calls = [] # id, name, arguments
|
||||||
|
tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
|
||||||
body["messages"].append(response_dict["choices"][0]["message"])
|
current_index = peek_json['choices'][0]['delta']['tool_calls'][0]['index']
|
||||||
for tool_call in full_response_dict["choices"][0].get("tool_calls", []):
|
fill_with_delta(tool_calls[current_index], peek_json)
|
||||||
tool_function_name = tool_call["function"]["name"]
|
|
||||||
tool_function_params = json.loads(tool_call["function"]["arguments"])
|
|
||||||
|
|
||||||
try:
|
async for data in generator:
|
||||||
log.debug(f"smonux 24 {tool_function_name}")
|
log.debug(f"smonux 24 {data=}")
|
||||||
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
|
if data in (b'\n', b'data: [DONE]\n'):
|
||||||
except Exception as e:
|
continue
|
||||||
|
delta = extract_json(data)
|
||||||
|
if delta['choices'][0]['finish_reason'] is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
i = delta['choices'][0]['delta']['tool_calls'][0]['index']
|
||||||
|
if i != current_index:
|
||||||
|
tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
|
||||||
|
current_index = i
|
||||||
|
|
||||||
|
fill_with_delta(tool_calls[i], delta)
|
||||||
|
|
||||||
|
log.debug(f"tools to call { [ t['function']['name'] for t in tool_calls] }")
|
||||||
|
|
||||||
|
body["messages"].append( {'role': 'assistant',
|
||||||
|
'content': None,
|
||||||
|
'refusal': None,
|
||||||
|
'tool_calls': tool_calls })
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
tool_function_name = tool_call["function"]["name"]
|
||||||
|
tool_function_params = json.loads(tool_call["function"]["arguments"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
|
||||||
|
except Exception as e:
|
||||||
tool_output = str(e)
|
tool_output = str(e)
|
||||||
|
|
||||||
# Append the tool output to the messages
|
# Append the tool output to the messages
|
||||||
body["messages"].append({
|
body["messages"].append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id" : tool_call["id"],
|
"tool_call_id" : tool_call["id"],
|
||||||
"name": tool_function_name,
|
"name": tool_function_name,
|
||||||
"content": tool_output
|
"content": tool_output
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Make another request to the model with the updated context
|
||||||
|
log.debug("calling the model again with tool output included")
|
||||||
update_body_request(request, body)
|
update_body_request(request, body)
|
||||||
response = await call_next(request)
|
response = await generate_chat_completions(form_data = body, user = user )
|
||||||
original_generator = response.body_iterator
|
generator = response.body_iterator.__aiter__()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error: {e}")
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_wrapper(response.body_iterator, data_items),
|
stream_wrapper(response.body_iterator, data_items),
|
||||||
@ -491,7 +539,7 @@ async def handle_streaming_response(request: Request, response: Response,
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse:
|
async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel) -> JSONResponse:
|
||||||
# It only should be one response
|
# There should be only one response since we are in the async scenario
|
||||||
async for data in response.body_iterator:
|
async for data in response.body_iterator:
|
||||||
content = data
|
content = data
|
||||||
response_dict = json.loads(content)
|
response_dict = json.loads(content)
|
||||||
@ -786,8 +834,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
if body.get("stream", False) is False:
|
if body.get("stream", False) is False:
|
||||||
return await handle_nonstreaming_response(request, first_response, tools, user)
|
return await handle_nonstreaming_response(request, first_response, tools, user)
|
||||||
log.debug("smonux 20")
|
|
||||||
return await handle_streaming_response(request, first_response, tools, data_items, call_next)
|
return await handle_streaming_response(request, first_response, tools, data_items, call_next, user)
|
||||||
|
|
||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
return {"type": "http.request", "body": body, "more_body": False}
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
Loading…
Reference in New Issue
Block a user