mirror of https://github.com/open-webui/open-webui (synced 2025-06-22 18:07:17 +00:00)
Feat: models are aware of attached files when they choose a tool to call
This commit is contained in:
commit 833cf5130c (parent 235489cfc5)
@@ -1558,7 +1558,12 @@ The format for the JSON response is strictly:
     {"name": "toolName1", "parameters": {"key1": "value1"}},
     {"name": "toolName2", "parameters": {"key2": "value2"}}
   ]
-}"""
+}
+
+<context>
+{{CONTEXT}}
+</context>
+"""


 DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
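With this change, the tool-selection prompt template closes its JSON example with a plain } and then appends a <context> block whose {{CONTEXT}} placeholder is filled from the chat's attached files. A rough, made-up illustration of how the tail of the rendered prompt might look (the excerpt text is fabricated; the real content comes from the retrieved sources):

        {"name": "toolName2", "parameters": {"key2": "value2"}}
      ]
    }

    <context>
    <source id="1" name="report.pdf">Quarterly revenue grew 12%.</source>
    </context>

If nothing is attached, {{CONTEXT}} is replaced with an empty string and the block simply stays empty.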
@@ -99,7 +99,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])


 async def chat_completion_tools_handler(
-    request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
+    request: Request, body: dict, extra_params: dict, user: UserModel, models: dict, tools: dict, context: str,
 ) -> tuple[dict, dict]:
     async def get_content_from_response(response) -> Optional[str]:
         content = None
@@ -156,7 +156,7 @@ async def chat_completion_tools_handler(
     template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

     tools_function_calling_prompt = tools_function_calling_generation_template(
-        template, tools_specs
+        template, tools_specs, context,
     )
     payload = get_tools_function_calling_payload(
         body["messages"], task_model_id, tools_function_calling_prompt
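Because context is added as a new positional parameter with no default, any other place that calls chat_completion_tools_handler has to be updated as well. A hedged sketch of a call site that has no retrieved sources (the surrounding variables are assumed to exist):

    # Pass an empty string when there is nothing to put into {{CONTEXT}}.
    form_data, flags = await chat_completion_tools_handler(
        request, form_data, extra_params, user, models, tools_dict, "",
    )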
@@ -716,6 +716,31 @@ def apply_params_to_form_data(form_data, model):
     return form_data


+def create_context_string_from_sources(sources: list[dict[str, Any]]) -> str:
+    context_string = ""
+    citation_idx = {}
+    for source in sources:
+        if "document" in source:
+            for doc_context, doc_meta in zip(
+                source["document"], source["metadata"]
+            ):
+                source_name = source.get("source", {}).get("name", None)
+                citation_id = (
+                    doc_meta.get("source", None)
+                    or source.get("source", {}).get("id", None)
+                    or "N/A"
+                )
+                if citation_id not in citation_idx:
+                    citation_idx[citation_id] = len(citation_idx) + 1
+                context_string += (
+                    f'<source id="{citation_idx[citation_id]}"'
+                    + (f' name="{source_name}"' if source_name else "")
+                    + f">{doc_context}</source>\n"
+                )
+
+    return context_string.strip()
+
+
 async def process_chat_payload(request, form_data, user, metadata, model):
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
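create_context_string_from_sources assigns citation ids in order of first appearance and reuses the same id when several chunks come from the same source. A minimal, fabricated usage sketch (the dict shape mirrors the fields the function reads):

    sources = [
        {
            "source": {"id": "file-abc", "name": "report.pdf"},
            "document": ["Quarterly revenue grew 12%."],
            "metadata": [{"source": "file-abc"}],
        }
    ]
    print(create_context_string_from_sources(sources))
    # -> <source id="1" name="report.pdf">Quarterly revenue grew 12%.</source>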
@@ -901,6 +926,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                 "server": tool_server,
             }

+    try:
+        form_data, flags = await chat_completion_files_handler(request, form_data, user)
+        sources.extend(flags.get("sources", []))
+    except Exception as e:
+        log.exception(e)
+
     if tools_dict:
         if metadata.get("function_calling") == "native":
             # If the function calling is native, then call the tools function calling handler
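Together with the next hunk, this moves the files/RAG retrieval in front of tool selection, so the sources gathered from attached files already exist when the tool-selection prompt is built. The resulting order, sketched with error handling and the native-function-calling branch left out:

    sources = []

    # 1. Collect sources from attached files first.
    form_data, flags = await chat_completion_files_handler(request, form_data, user)
    sources.extend(flags.get("sources", []))

    # 2. Build the citation context once.
    context_string = create_context_string_from_sources(sources)

    # 3. Tool selection now sees that context through {{CONTEXT}}.
    form_data, flags = await chat_completion_tools_handler(
        request, form_data, extra_params, user, models, tools_dict, context_string,
    )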
@@ -912,44 +943,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         else:
             # If the function calling is not native, then call the tools function calling handler
             try:
+                context_string = create_context_string_from_sources(sources)
                 form_data, flags = await chat_completion_tools_handler(
-                    request, form_data, extra_params, user, models, tools_dict
+                    request, form_data, extra_params, user, models, tools_dict, context_string,
                 )
                 sources.extend(flags.get("sources", []))

             except Exception as e:
                 log.exception(e)

-    try:
-        form_data, flags = await chat_completion_files_handler(request, form_data, user)
-        sources.extend(flags.get("sources", []))
-    except Exception as e:
-        log.exception(e)
-
     # If context is not empty, insert it into the messages
     if len(sources) > 0:
-        context_string = ""
-        citation_idx = {}
-        for source in sources:
-            if "document" in source:
-                for doc_context, doc_meta in zip(
-                    source["document"], source["metadata"]
-                ):
-                    source_name = source.get("source", {}).get("name", None)
-                    citation_id = (
-                        doc_meta.get("source", None)
-                        or source.get("source", {}).get("id", None)
-                        or "N/A"
-                    )
-                    if citation_id not in citation_idx:
-                        citation_idx[citation_id] = len(citation_idx) + 1
-                    context_string += (
-                        f'<source id="{citation_idx[citation_id]}"'
-                        + (f' name="{source_name}"' if source_name else "")
-                        + f">{doc_context}</source>\n"
-                    )
-
-        context_string = context_string.strip()
+        context_string = create_context_string_from_sources(sources)
         prompt = get_last_user_message(form_data["messages"])

         if prompt is None:
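The inlined citation-building loop that used to live here is replaced by the shared helper, and the same helper result now also feeds the tool-selection prompt. When no sources were retrieved the helper returns an empty string, so tool calls without attachments should behave as before:

    # No retrieved sources: the {{CONTEXT}} block renders empty.
    assert create_context_string_from_sources([]) == ""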
@@ -336,6 +336,7 @@ def moa_response_generation_template(
     return template


-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tools_function_calling_generation_template(template: str, tools_specs: str, context: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
+    template = template.replace("{{CONTEXT}}", context)
     return template
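The template helper now performs two plain string substitutions. A small end-to-end sketch of how the pieces of this commit fit together, using shortened stand-ins rather than the real prompt constants:

    template = "Tools: {{TOOLS}}\n<context>\n{{CONTEXT}}\n</context>"
    tools_specs = '[{"name": "get_weather"}]'
    context = '<source id="1" name="report.pdf">Quarterly revenue grew 12%.</source>'

    prompt = tools_function_calling_generation_template(template, tools_specs, context)
    # The task model that picks tools now sees the attached-file excerpts
    # inside <context>...</context> alongside the tool specs.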