chore: formatting with black

Samuel 2024-11-19 07:58:46 +00:00
parent 0fb9e8d22c
commit 18c092d29e
2 changed files with 189 additions and 110 deletions


@@ -571,8 +571,8 @@ async def generate_chat_completion(
del payload["max_tokens"]
# openai.com fails if it gets unknown arguments
if 'native_tool_call' in payload:
del payload['native_tool_call']
if "native_tool_call" in payload:
del payload["native_tool_call"]
# Fix: O1 does not support the "system" role; modify "system" to "user"
if is_o1 and payload["messages"][0]["role"] == "system":
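A minimal sketch of the conversion the comment above describes (hypothetical code, not part of this diff):

    if is_o1 and payload["messages"][0]["role"] == "system":
        # Assumed body: re-label the leading system message as a user message,
        # since O1 rejects requests whose first message uses the "system" role.
        payload["messages"][0]["role"] = "user"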


@@ -350,6 +350,7 @@ async def get_content_from_response(response) -> Optional[str]:
content = response["choices"][0]["message"]["content"]
return content
def get_tools_body(
body: dict, user: UserModel, extra_params: dict, models
) -> tuple[dict, dict]:
@@ -361,10 +362,10 @@ def get_tools_body(
return body, {}
task_model_id = get_task_model_id(
body["model"],
app.state.config.TASK_MODEL,
app.state.config.TASK_MODEL_EXTERNAL,
models,
body["model"],
app.state.config.TASK_MODEL,
app.state.config.TASK_MODEL_EXTERNAL,
models,
)
tools = get_tools(
@@ -382,14 +383,13 @@ def get_tools_body(
specs = [tool["spec"] for tool in tools.values()]
tools_args = [ { "type" : "function", "function" : spec } \
for spec in specs ]
tools_args = [{"type": "function", "function": spec} for spec in specs]
body["tools"] = tools_args
return body, tools
def update_body_request(request: Request,
body: dict) -> None:
def update_body_request(request: Request, body: dict) -> None:
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
@@ -397,37 +397,44 @@ def update_body_request(request: Request,
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
]
return None
def extract_json(binary:bytes)->dict:
def extract_json(binary: bytes) -> dict:
try:
s = binary.decode("utf-8")
return json.loads(s[ s.find("{") : s.rfind("}") + 1 ])
return json.loads(s[s.find("{") : s.rfind("}") + 1])
except Exception as e:
return None
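A minimal usage sketch for extract_json (hypothetical inputs, not part of this diff): slicing from the first "{" to the last "}" lets SSE lines with a "data: " prefix still parse, while non-JSON lines fall through to None:

    extract_json(b'data: {"choices": [{"delta": {}}]}\n')  # -> {"choices": [{"delta": {}}]}
    extract_json(b"data: [DONE]\n")  # -> None (no braces, so json.loads raises)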
def fill_with_delta(fcall_dict:dict, delta:dict) -> None:
if not 'delta' in delta['choices'][0]:
return
j = delta['choices'][0]['delta']['tool_calls'][0]
if 'id' in j:
fcall_dict["id"] += j["id"] or ''
if 'function' in j:
if 'name' in j['function']:
fcall_dict['function']['name'] += j['function']['name'] or ''
if 'arguments' in j['function']:
fcall_dict['function']['arguments'] += j['function']['arguments'] or ''
async def handle_streaming_response(request: Request, response: Response,
tools: dict,
data_items: list,
call_next: Callable,
user : UserModel) -> StreamingResponse:
def fill_with_delta(fcall_dict: dict, delta: dict) -> None:
if not "delta" in delta["choices"][0]:
return
j = delta["choices"][0]["delta"]["tool_calls"][0]
if "id" in j:
fcall_dict["id"] += j["id"] or ""
if "function" in j:
if "name" in j["function"]:
fcall_dict["function"]["name"] += j["function"]["name"] or ""
if "arguments" in j["function"]:
fcall_dict["function"]["arguments"] += j["function"]["arguments"] or ""
async def handle_streaming_response(
request: Request,
response: Response,
tools: dict,
data_items: list,
call_next: Callable,
user: UserModel,
) -> StreamingResponse:
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
@@ -442,84 +449,122 @@ async def handle_streaming_response(request: Request, response: Response,
while True:
peek = await anext(generator)
peek_json = extract_json(peek)
if peek == b'data: [DONE]\n' and len(citations) > 0 :
yield wrap_item(json.dumps({ "citations" : citations}))
if peek == b"data: [DONE]\n" and len(citations) > 0:
yield wrap_item(json.dumps({"citations": citations}))
if peek_json is None or \
not 'choices' in peek_json or \
not 'tool_calls' in peek_json['choices'][0]['delta']:
if (
peek_json is None
or not "choices" in peek_json
or not "tool_calls" in peek_json["choices"][0]["delta"]
):
yield peek
continue
# We reached a tool call so we consume all the messages to assemble it
log.debug("async tool call detected")
tool_calls = [] # id, name, arguments
tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
current_index = peek_json['choices'][0]['delta']['tool_calls'][0]['index']
tool_calls = [] # id, name, arguments
tool_calls.append(
{
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
current_index = peek_json["choices"][0]["delta"]["tool_calls"][0][
"index"
]
fill_with_delta(tool_calls[current_index], peek_json)
async for data in generator:
delta = extract_json(data)
if delta is None or \
not 'choices' in delta or \
not 'tool_calls' in delta['choices'][0]['delta'] or \
delta['choices'][0].get('finish_reason', None) is not None:
if (
delta is None
or not "choices" in delta
or not "tool_calls" in delta["choices"][0]["delta"]
or delta["choices"][0].get("finish_reason", None) is not None
):
continue
i = delta['choices'][0]['delta']['tool_calls'][0]['index']
i = delta["choices"][0]["delta"]["tool_calls"][0]["index"]
if i != current_index:
tool_calls.append({'id':'', 'type': 'function', 'function' : {'name':'', 'arguments':''}})
tool_calls.append(
{
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
current_index = i
fill_with_delta(tool_calls[i], delta)
log.debug(f"tools to call { [ t['function']['name'] for t in tool_calls] }")
log.debug(
f"tools to call { [ t['function']['name'] for t in tool_calls] }"
)
body["messages"].append( {'role': 'assistant',
'content': None,
'refusal': None,
'tool_calls': tool_calls })
body["messages"].append(
{
"role": "assistant",
"content": None,
"refusal": None,
"tool_calls": tool_calls,
}
)
for tool_call in tool_calls:
tool_function_name = tool_call["function"]["name"]
tool_function_params = extract_json(tool_call["function"]["arguments"])
tool_function_params = extract_json(
tool_call["function"]["arguments"]
)
if not tool_call["function"]["arguments"]:
tool_function_params = {}
else:
tool_function_params = json.loads(tool_call["function"]["arguments"])
tool_function_params = json.loads(
tool_call["function"]["arguments"]
)
log.debug(f"calling {tool_function_name} with params {tool_function_params}")
log.debug(
f"calling {tool_function_name} with params {tool_function_params}"
)
try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
tool_output = await tools[tool_function_name]["callable"](
**tool_function_params
)
except Exception as e:
tool_output = str(e)
# With several tools potentially called together, this may not make sense
if tools[tool_function_name]["file_handler"] and "files" in body.get("metadata", {}):
if tools[tool_function_name][
"file_handler"
] and "files" in body.get("metadata", {}):
del body["metadata"]["files"]
if tools[tool_function_name]["citation"]:
citations.append( {
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
})
citations.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
}
)
# Append the tool output to the messages
body["messages"].append({
"role": "tool",
"tool_call_id" : tool_call["id"],
"name": tool_function_name,
"content": tool_output
})
body["messages"].append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"name": tool_function_name,
"content": tool_output,
}
)
# Make another request to the model with the updated context
log.debug("calling the model again with tool output included")
update_body_request(request, body)
response = await generate_chat_completions(form_data = body, user = user )
response = await generate_chat_completions(form_data=body, user=user)
# body_iterator here does not have __anext__() so it has to be done this way
generator = response.body_iterator.__aiter__()
@@ -529,15 +574,21 @@ async def handle_streaming_response(request: Request, response: Response,
log.exception(f"Error: {e}")
return StreamingResponse(
stream_wrapper(response.body_iterator, data_items),
headers=dict(response.headers),
)
stream_wrapper(response.body_iterator, data_items),
headers=dict(response.headers),
)
async def handle_nonstreaming_response(request: Request, response: Response,
tools: dict, user: UserModel,
data_items: list, models) -> JSONResponse:
async def handle_nonstreaming_response(
request: Request,
response: Response,
tools: dict,
user: UserModel,
data_items: list,
models,
) -> JSONResponse:
# There should only be one response since we are in the non-streaming scenario
content = ''
content = ""
async for data in response.body_iterator:
content += data.decode()
citations = []
@@ -549,8 +600,10 @@ async def handle_nonstreaming_response(request: Request, response: Response,
is_ollama = True
is_openai = not is_ollama
while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \
(is_openai and "tool_calls" in response_dict.get("choices", [{}])[0].get("message",{}) ):
while (is_ollama and "tool_calls" in response_dict.get("message", {})) or (
is_openai
and "tool_calls" in response_dict.get("choices", [{}])[0].get("message", {})
):
if is_ollama:
message = response_dict.get("message", {})
tool_calls = message.get("tool_calls", [])
@@ -562,46 +615,56 @@ async def handle_nonstreaming_response(request: Request, response: Response,
for tool_call in tool_calls:
# fix for cohere
if 'index' in tool_call:
del tool_call['index']
if "index" in tool_call:
del tool_call["index"]
tool_function_name = tool_call["function"]["name"]
if not tool_call["function"]["arguments"]:
tool_function_params = {}
else:
if is_openai:
tool_function_params = json.loads(tool_call["function"]["arguments"])
tool_function_params = json.loads(
tool_call["function"]["arguments"]
)
if is_ollama:
tool_function_params = tool_call["function"]["arguments"]
try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
tool_output = await tools[tool_function_name]["callable"](
**tool_function_params
)
except Exception as e:
tool_output = str(e)
if tools.get(tool_function_name, {}).get("citation", False):
citations.append( {
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
})
citations.append(
{
"source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output],
"metadata": [{"source": tool_function_name}],
}
)
# Append the tool output to the messages
body["messages"].append({
"role": "tool",
"tool_call_id" : tool_call.get("id",""),
"name": tool_function_name,
"content": tool_output
})
body["messages"].append(
{
"role": "tool",
"tool_call_id": tool_call.get("id", ""),
"name": tool_function_name,
"content": tool_output,
}
)
# Make another request to the model with the updated context
update_body_request(request, body)
response_dict = await generate_chat_completions(form_data = body, user = user, as_openai = is_openai)
response_dict = await generate_chat_completions(
form_data=body, user=user, as_openai=is_openai
)
#FIXME: is it possible to handle citations?
return JSONResponse(content = response_dict)
# FIXME: is it possible to handle citations?
return JSONResponse(content=response_dict)
def get_task_model_id(
@@ -621,7 +684,7 @@ def get_task_model_id(
async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict, models:dict
body: dict, user: UserModel, extra_params: dict, models: dict
) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
@@ -889,17 +952,21 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
body, tools = get_tools_body(body, user, extra_params, models)
if model["owned_by"] == "ollama" and \
body["metadata"]["native_tool_call"] and \
body.get("stream", False):
log.info("Ollama models don't support function calling in streaming yet. forcing native_tool_call to False")
if (
model["owned_by"] == "ollama"
and body["metadata"]["native_tool_call"]
and body.get("stream", False)
):
log.info(
"Ollama models don't support function calling in streaming yet. forcing native_tool_call to False"
)
body["metadata"]["native_tool_call"] = False
if not body["metadata"]["native_tool_call"]:
del body["tools"] # we won't use those
del body["tools"] # we won't use those
try:
body, flags = await chat_completion_tools_handler(
body, user, extra_params, models
body, user, extra_params, models
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
@@ -960,7 +1027,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
return response
return response
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
@@ -978,8 +1045,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
else:
if not body.get("stream", False):
return await handle_nonstreaming_response(request, response, tools, user, data_items, models)
return await handle_streaming_response(request, response, tools, data_items, call_next, user)
return await handle_nonstreaming_response(
request, response, tools, user, data_items, models
)
return await handle_streaming_response(
request, response, tools, data_items, call_next, user
)
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
@@ -1423,8 +1494,10 @@ async def get_base_models(user=Depends(get_admin_user)):
@app.post("/api/chat/completions")
async def generate_chat_completions(
form_data: dict, user=Depends(get_verified_user),
bypass_filter: bool = False, as_openai: bool = False
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: bool = False,
as_openai: bool = False,
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
@@ -1530,11 +1603,17 @@ async def generate_chat_completions(
if form_data.stream:
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response) if as_openai else response,
(
convert_streaming_response_ollama_to_openai(response)
if as_openai
else response
),
headers=dict(response.headers),
)
else:
return convert_response_ollama_to_openai(response) if as_openai else response
return (
convert_response_ollama_to_openai(response) if as_openai else response
)
else:
return await generate_openai_chat_completion(
form_data, user=user, bypass_filter=bypass_filter