Feat: models are aware of attached files when they choose a tool to call

This commit is contained in:
zzzevaka 2025-05-31 22:09:12 +02:00
parent 235489cfc5
commit 833cf5130c
3 changed files with 44 additions and 33 deletions

View File

@ -1558,7 +1558,12 @@ The format for the JSON response is strictly:
{"name": "toolName1", "parameters": {"key1": "value1"}},
{"name": "toolName2", "parameters": {"key2": "value2"}}
]
}"""
}
<context>
{{CONTEXT}}
</context>
"""
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).

View File

@ -99,7 +99,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models: dict, tools: dict, context: str,
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@ -156,7 +156,7 @@ async def chat_completion_tools_handler(
template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
tools_function_calling_prompt = tools_function_calling_generation_template(
template, tools_specs
template, tools_specs, context,
)
payload = get_tools_function_calling_payload(
body["messages"], task_model_id, tools_function_calling_prompt
@ -716,6 +716,31 @@ def apply_params_to_form_data(form_data, model):
return form_data
def create_context_string_from_sources(sources: list[dict[str, Any]]) -> str:
context_string = ""
citation_idx = {}
for source in sources:
if "document" in source:
for doc_context, doc_meta in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
citation_id = (
doc_meta.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1
context_string += (
f'<source id="{citation_idx[citation_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{doc_context}</source>\n"
)
return context_string.strip()
async def process_chat_payload(request, form_data, user, metadata, model):
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@ -901,6 +926,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
try:
form_data, flags = await chat_completion_files_handler(request, form_data, user)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
@ -912,44 +943,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
else:
# If the function calling is not native, then call the tools function calling handler
try:
context_string = create_context_string_from_sources(sources)
form_data, flags = await chat_completion_tools_handler(
request, form_data, extra_params, user, models, tools_dict
request, form_data, extra_params, user, models, tools_dict, context_string,
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try:
form_data, flags = await chat_completion_files_handler(request, form_data, user)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citation_idx = {}
for source in sources:
if "document" in source:
for doc_context, doc_meta in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
citation_id = (
doc_meta.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1
context_string += (
f'<source id="{citation_idx[citation_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{doc_context}</source>\n"
)
context_string = context_string.strip()
context_string = create_context_string_from_sources(sources)
prompt = get_last_user_message(form_data["messages"])
if prompt is None:

View File

@ -336,6 +336,7 @@ def moa_response_generation_template(
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
def tools_function_calling_generation_template(template: str, tools_specs: str, context: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
template = template.replace("{{CONTEXT}}", context)
return template