factor out get_function_calling_payload

This commit is contained in:
Michael Poluektov 2024-08-11 09:05:22 +01:00
parent ff9d899f9c
commit e86688284a

View File

@ -322,6 +322,26 @@ async def call_tool_from_completion(
return None
def get_function_calling_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
)
prompt = f"History:\n{history}\nQuery: {user_message}"
return {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user, extra_params
) -> tuple[Optional[str], Optional[dict], bool]:
@ -331,30 +351,7 @@ async def get_function_call_response(
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
payload = get_function_calling_payload(messages, task_model_id, content)
try:
payload = filter_pipeline(payload, user)