mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
refac
This commit is contained in:
parent
c1823b4b73
commit
c4946d42e0
@ -283,21 +283,6 @@ def get_filter_function_ids(model):
|
|||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
async def get_content_from_response(response) -> Optional[str]:
|
|
||||||
content = None
|
|
||||||
if hasattr(response, "body_iterator"):
|
|
||||||
async for chunk in response.body_iterator:
|
|
||||||
data = json.loads(chunk.decode("utf-8"))
|
|
||||||
content = data["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
# Cleanup any remaining background tasks if necessary
|
|
||||||
if response.background is not None:
|
|
||||||
await response.background()
|
|
||||||
else:
|
|
||||||
content = response["choices"][0]["message"]["content"]
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_call_payload(messages, task_model_id, content):
|
def get_tool_call_payload(messages, task_model_id, content):
|
||||||
user_message = get_last_user_message(messages)
|
user_message = get_last_user_message(messages)
|
||||||
history = "\n".join(
|
history = "\n".join(
|
||||||
@ -403,8 +388,8 @@ def get_tool_with_custom_params(
|
|||||||
|
|
||||||
|
|
||||||
# Mutation on extra_params
|
# Mutation on extra_params
|
||||||
def get_configured_tools(
|
def get_tools(
|
||||||
tool_ids: list[str], extra_params: dict, user: UserModel
|
tool_ids: list[str], user: UserModel, extra_params: dict
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
tools = {}
|
tools = {}
|
||||||
for tool_id in tool_ids:
|
for tool_id in tool_ids:
|
||||||
@ -420,6 +405,7 @@ def get_configured_tools(
|
|||||||
extra_params["__id__"] = tool_id
|
extra_params["__id__"] = tool_id
|
||||||
has_citation = hasattr(module, "citation") and module.citation
|
has_citation = hasattr(module, "citation") and module.citation
|
||||||
handles_files = hasattr(module, "file_handler") and module.file_handler
|
handles_files = hasattr(module, "file_handler") and module.file_handler
|
||||||
|
|
||||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||||
module.valves = module.Valves(**valves)
|
module.valves = module.Valves(**valves)
|
||||||
@ -459,35 +445,51 @@ def get_configured_tools(
|
|||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
async def get_content_from_response(response) -> Optional[str]:
|
||||||
|
content = None
|
||||||
|
if hasattr(response, "body_iterator"):
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
data = json.loads(chunk.decode("utf-8"))
|
||||||
|
content = data["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
# Cleanup any remaining background tasks if necessary
|
||||||
|
if response.background is not None:
|
||||||
|
await response.background()
|
||||||
|
else:
|
||||||
|
content = response["choices"][0]["message"]["content"]
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: UserModel, extra_params: dict
|
body: dict, user: UserModel, extra_params: dict
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
skip_files = False
|
skip_files = False
|
||||||
contexts = []
|
contexts = []
|
||||||
citations = []
|
citations = []
|
||||||
task_model_id = get_task_model_id(body["model"])
|
|
||||||
|
|
||||||
|
task_model_id = get_task_model_id(body["model"])
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
tool_ids = body.pop("tool_ids", None)
|
tool_ids = body.pop("tool_ids", None)
|
||||||
if not tool_ids:
|
if not tool_ids:
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
log.debug(f"{tool_ids=}")
|
log.debug(f"{tool_ids=}")
|
||||||
|
|
||||||
custom_params = {
|
custom_params = {
|
||||||
**extra_params,
|
**extra_params,
|
||||||
"__model__": app.state.MODELS[task_model_id],
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
"__messages__": body["messages"],
|
"__messages__": body["messages"],
|
||||||
"__files__": body.get("files", []),
|
"__files__": body.get("files", []),
|
||||||
}
|
}
|
||||||
configured_tools = get_configured_tools(tool_ids, custom_params, user)
|
tools = get_tools(tool_ids, user, custom_params)
|
||||||
|
log.info(f"{tools=}")
|
||||||
|
|
||||||
log.info(f"{configured_tools=}")
|
specs = [tool["spec"] for tool in tools.values()]
|
||||||
|
|
||||||
specs = [tool["spec"] for tool in configured_tools.values()]
|
|
||||||
tools_specs = json.dumps(specs)
|
tools_specs = json.dumps(specs)
|
||||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||||
content = tool_calling_generation_template(template, tools_specs)
|
content = tool_calling_generation_template(template, tools_specs)
|
||||||
payload = get_tool_call_payload(body["messages"], task_model_id, content)
|
payload = get_tool_call_payload(body["messages"], task_model_id, content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user)
|
payload = filter_pipeline(payload, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -503,16 +505,18 @@ async def chat_completion_tools_handler(
|
|||||||
|
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
tool_name = result.get("name", None)
|
tool_name = result.get("name", None)
|
||||||
if tool_name not in configured_tools:
|
if tool_name not in tools:
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
tool_params = result.get("parameters", {})
|
tool_params = result.get("parameters", {})
|
||||||
toolkit_id = configured_tools[tool_name]["toolkit_id"]
|
toolkit_id = tools[tool_name]["toolkit_id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_output = await configured_tools[tool_name]["callable"](**tool_params)
|
tool_output = await tools[tool_name]["callable"](**tool_params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_output = str(e)
|
tool_output = str(e)
|
||||||
if configured_tools[tool_name]["citation"]:
|
|
||||||
|
if tools[tool_name]["citation"]:
|
||||||
citations.append(
|
citations.append(
|
||||||
{
|
{
|
||||||
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
|
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
|
||||||
@ -520,7 +524,7 @@ async def chat_completion_tools_handler(
|
|||||||
"metadata": [{"source": tool_name}],
|
"metadata": [{"source": tool_name}],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if configured_tools[tool_name]["file_handler"]:
|
if tools[tool_name]["file_handler"]:
|
||||||
skip_files = True
|
skip_files = True
|
||||||
|
|
||||||
if isinstance(tool_output, str):
|
if isinstance(tool_output, str):
|
||||||
|
Loading…
Reference in New Issue
Block a user