This commit is contained in:
Timothy J. Baek 2024-06-11 11:15:43 -07:00
parent e4fe1fff97
commit 9d16dd997a

View File

@ -168,11 +168,25 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
async def get_function_call_response(prompt, tool_id, template, task_model_id, user): async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs) content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role']}: {message['content']}"
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = { payload = {
"model": task_model_id, "model": task_model_id,
"messages": [ "messages": [
@ -300,16 +314,16 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
): ):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = "" context = ""
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
if "tool_ids" in data: if "tool_ids" in data:
print(data["tool_ids"]) print(data["tool_ids"])
prompt = get_last_user_message(data["messages"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
response = await get_function_call_response( response = await get_function_call_response(
prompt=prompt, messages=data["messages"],
tool_id=tool_id, tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id, task_model_id=task_model_id,
@ -839,7 +853,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response( return await get_function_call_response(
form_data["prompt"], form_data["tool_id"], template, model_id, user form_data["messages"], form_data["tool_id"], template, model_id, user
) )