refactor get_function_call_response

This commit is contained in:
Michael Poluektov 2024-08-10 14:25:20 +01:00
parent 9fb70969d7
commit a68b918cbb

View File

@ -297,6 +297,30 @@ async def get_content_from_response(response) -> Optional[str]:
return content
async def call_tool_from_completion(
result: dict, extra_params: dict, toolkit_module
) -> Optional[str]:
if "name" not in result:
return None
tool = getattr(toolkit_module, result["name"])
try:
# Get the signature of the function
sig = inspect.signature(tool)
params = result["parameters"]
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if inspect.iscoroutinefunction(tool):
return await tool(**params)
else:
return tool(**params)
except Exception as e:
print(f"Error: {e}")
return None
async def get_function_call_response(
messages,
files,
@ -306,7 +330,7 @@ async def get_function_call_response(
user,
__event_emitter__=None,
__event_call__=None,
):
) -> tuple[Optional[str], Optional[dict], bool]:
tool = Tools.get_tool_by_id(tool_id)
if tool is None:
return None, None, False
@ -343,7 +367,43 @@ async def get_function_call_response(
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
extra_params = {
"__model__": app.state.MODELS[task_model_id],
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__user__": __user__,
}
file_handler = hasattr(toolkit_module, "file_handler")
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
try:
response = await generate_chat_completions(form_data=payload, user=user)
@ -353,85 +413,21 @@ async def get_function_call_response(
return None, None, False
# Parse the function response
print(f"content: {content}")
log.debug(f"content: {content}")
result = json.loads(content)
print(result)
if "name" not in result:
return None, None, False
function_result = await call_tool_from_completion(
result, extra_params, toolkit_module
)
# Call the function
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
citation = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
else:
citation = None
# Add the function result to the system prompt
if function_result is not None: