add __tools__ custom param

This commit is contained in:
Michael Poluektov 2024-08-19 11:08:27 +01:00
parent 18965dcdac
commit 13c03bfd7d
2 changed files with 22 additions and 16 deletions

View File

@ -26,6 +26,7 @@ from utils.misc import (
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
from utils.tools import get_tools
from config import ( from config import (
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
@ -47,6 +48,7 @@ from config import (
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_EMAIL_CLAIM, OAUTH_EMAIL_CLAIM,
ENABLE_TOOLS_FILTER,
) )
from apps.socket.main import get_event_call, get_event_emitter from apps.socket.main import get_event_call, get_event_emitter
@ -271,7 +273,7 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return params return params
async def generate_function_chat_completion(form_data, user): async def generate_function_chat_completion(form_data, user, files, tool_ids):
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None) metadata = form_data.pop("metadata", None)
@ -286,6 +288,21 @@ async def generate_function_chat_completion(form_data, user):
__event_call__ = get_event_call(metadata) __event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None) __task__ = metadata.get("task", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
}
if not ENABLE_TOOLS_FILTER:
tools_params = {
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
}
configured_tools = get_tools(app, tool_ids, user, tools_params)
extra_params["__tools__"] = configured_tools
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
@ -298,16 +315,7 @@ async def generate_function_chat_completion(form_data, user):
function_module = get_function_module(pipe_id) function_module = get_function_module(pipe_id)
pipe = function_module.pipe pipe = function_module.pipe
params = get_function_params( params = get_function_params(function_module, form_data, user, extra_params)
function_module,
form_data,
user,
{
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
},
)
if form_data["stream"]: if form_data["stream"]:

View File

@ -994,13 +994,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
files = form_data.pop("files", None)
tool_ids = form_data.pop("tool_ids", None)
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user, files, tool_ids)
for key in ["tool_ids", "files"]:
if key in form_data:
del form_data[key]
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user) return await generate_ollama_chat_completion(form_data, user=user)
else: else: