mirror of
https://github.com/open-webui/open-webui
synced 2025-05-17 20:05:08 +00:00
refactor get_function_call_response
This commit is contained in:
parent
9fb70969d7
commit
a68b918cbb
118
backend/main.py
118
backend/main.py
@ -297,6 +297,30 @@ async def get_content_from_response(response) -> Optional[str]:
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
async def call_tool_from_completion(
|
||||||
|
result: dict, extra_params: dict, toolkit_module
|
||||||
|
) -> Optional[str]:
|
||||||
|
if "name" not in result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool = getattr(toolkit_module, result["name"])
|
||||||
|
try:
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(tool)
|
||||||
|
params = result["parameters"]
|
||||||
|
for key, value in extra_params.items():
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(tool):
|
||||||
|
return await tool(**params)
|
||||||
|
else:
|
||||||
|
return tool(**params)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_function_call_response(
|
||||||
messages,
|
messages,
|
||||||
files,
|
files,
|
||||||
@ -306,7 +330,7 @@ async def get_function_call_response(
|
|||||||
user,
|
user,
|
||||||
__event_emitter__=None,
|
__event_emitter__=None,
|
||||||
__event_call__=None,
|
__event_call__=None,
|
||||||
):
|
) -> tuple[Optional[str], Optional[dict], bool]:
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
return None, None, False
|
return None, None, False
|
||||||
@ -343,65 +367,12 @@ async def get_function_call_response(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
model = app.state.MODELS[task_model_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await generate_chat_completions(form_data=payload, user=user)
|
|
||||||
content = await get_content_from_response(response)
|
|
||||||
|
|
||||||
if content is None:
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
# Parse the function response
|
|
||||||
print(f"content: {content}")
|
|
||||||
result = json.loads(content)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
if "name" not in result:
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
if tool_id in webui_app.state.TOOLS:
|
if tool_id in webui_app.state.TOOLS:
|
||||||
toolkit_module = webui_app.state.TOOLS[tool_id]
|
toolkit_module = webui_app.state.TOOLS[tool_id]
|
||||||
else:
|
else:
|
||||||
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
|
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
|
||||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||||
|
|
||||||
file_handler = False
|
|
||||||
# check if toolkit_module has file_handler self variable
|
|
||||||
if hasattr(toolkit_module, "file_handler"):
|
|
||||||
file_handler = True
|
|
||||||
print("file_handler: ", file_handler)
|
|
||||||
|
|
||||||
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
|
|
||||||
valves = Tools.get_tool_valves_by_id(tool_id)
|
|
||||||
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
|
|
||||||
|
|
||||||
function = getattr(toolkit_module, result["name"])
|
|
||||||
function_result = None
|
|
||||||
citation = None
|
|
||||||
try:
|
|
||||||
# Get the signature of the function
|
|
||||||
sig = inspect.signature(function)
|
|
||||||
params = result["parameters"]
|
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
|
||||||
extra_params = {
|
|
||||||
"__model__": model,
|
|
||||||
"__id__": tool_id,
|
|
||||||
"__messages__": messages,
|
|
||||||
"__files__": files,
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
|
||||||
for key, value in extra_params.items():
|
|
||||||
if key in sig.parameters:
|
|
||||||
params[key] = value
|
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
|
||||||
# Call the function with the '__user__' parameter included
|
|
||||||
__user__ = {
|
__user__ = {
|
||||||
"id": user.id,
|
"id": user.id,
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
@ -414,15 +385,40 @@ async def get_function_call_response(
|
|||||||
__user__["valves"] = toolkit_module.UserValves(
|
__user__["valves"] = toolkit_module.UserValves(
|
||||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
extra_params = {
|
||||||
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
|
"__id__": tool_id,
|
||||||
|
"__messages__": messages,
|
||||||
|
"__files__": files,
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
"__user__": __user__,
|
||||||
|
}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(function):
|
file_handler = hasattr(toolkit_module, "file_handler")
|
||||||
function_result = await function(**params)
|
|
||||||
else:
|
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
|
||||||
function_result = function(**params)
|
valves = Tools.get_tool_valves_by_id(tool_id)
|
||||||
|
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
content = await get_content_from_response(response)
|
||||||
|
|
||||||
|
if content is None:
|
||||||
|
return None, None, False
|
||||||
|
|
||||||
|
# Parse the function response
|
||||||
|
log.debug(f"content: {content}")
|
||||||
|
result = json.loads(content)
|
||||||
|
|
||||||
|
function_result = await call_tool_from_completion(
|
||||||
|
result, extra_params, toolkit_module
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
||||||
citation = {
|
citation = {
|
||||||
@ -430,8 +426,8 @@ async def get_function_call_response(
|
|||||||
"document": [function_result],
|
"document": [function_result],
|
||||||
"metadata": [{"source": result["name"]}],
|
"metadata": [{"source": result["name"]}],
|
||||||
}
|
}
|
||||||
except Exception as e:
|
else:
|
||||||
print(e)
|
citation = None
|
||||||
|
|
||||||
# Add the function result to the system prompt
|
# Add the function result to the system prompt
|
||||||
if function_result is not None:
|
if function_result is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user