fix: adapting to the new get_task_model_id signature

This commit is contained in:
Samuel 2024-11-17 17:50:15 +00:00
parent d400e65601
commit b8ca03fafd

View File

@ -351,7 +351,7 @@ async def get_content_from_response(response) -> Optional[str]:
return content
def get_tools_body(
body: dict, user: UserModel, extra_params: dict
body: dict, user: UserModel, extra_params: dict, models
) -> tuple[dict, dict]:
metadata = body.get("metadata", {})
@ -360,14 +360,20 @@ def get_tools_body(
if not tool_ids:
return body, {}
task_model_id = get_task_model_id(body["model"])
task_model_id = get_task_model_id(
body["model"],
app.state.config.TASK_MODEL,
app.state.config.TASK_MODEL_EXTERNAL,
models,
)
tools = get_tools(
webui_app,
tool_ids,
user,
{
**extra_params,
"__model__": app.state.MODELS[task_model_id],
"__model__": models[task_model_id],
"__messages__": body["messages"],
"__files__": metadata.get("files", []),
},
@ -861,7 +867,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
}
body["metadata"] = metadata
body, tools = get_tools_body(body, user, extra_params)
body, tools = get_tools_body(body, user, extra_params, models)
if model["owned_by"] == "ollama" and \
body["metadata"]["native_tool_call"] and \