diff --git a/backend/main.py b/backend/main.py index 74fc4aa5d..187a0b720 100644 --- a/backend/main.py +++ b/backend/main.py @@ -168,11 +168,25 @@ app.state.MODELS = {} origins = ["*"] -async def get_function_call_response(prompt, tool_id, template, task_model_id, user): +async def get_function_call_response(messages, tool_id, template, task_model_id, user): tool = Tools.get_tool_by_id(tool_id) tools_specs = json.dumps(tool.specs, indent=2) content = tools_function_calling_generation_template(template, tools_specs) + user_message = get_last_user_message(messages) + prompt = ( + "History:\n" + + "\n".join( + [ + f"{message['role']}: {message['content']}" + for message in messages[::-1][:4] + ] + ) + + f"\nQuery: {user_message}" + ) + + print(prompt) + payload = { "model": task_model_id, "messages": [ @@ -300,16 +314,16 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL + prompt = get_last_user_message(data["messages"]) context = "" # If tool_ids field is present, call the functions if "tool_ids" in data: print(data["tool_ids"]) - prompt = get_last_user_message(data["messages"]) for tool_id in data["tool_ids"]: print(tool_id) response = await get_function_call_response( - prompt=prompt, + messages=data["messages"], tool_id=tool_id, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, task_model_id=task_model_id, @@ -839,7 +853,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE return await get_function_call_response( - form_data["prompt"], form_data["tool_id"], template, model_id, user + form_data["messages"], form_data["tool_id"], template, model_id, user )