put tool_ids and files in metadata

This commit is contained in:
Michael Poluektov 2024-08-20 15:41:49 +01:00
parent bcbcd5fde9
commit 2e3146263c
4 changed files with 24 additions and 22 deletions

View File

@ -731,12 +731,10 @@ async def generate_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
payload = {**form_data.model_dump(exclude_none=True)} payload = {**form_data.model_dump(exclude_none=True)}
for key in ["metadata", "files", "tool_ids"]: log.debug(f"{payload = }")
if key in payload: if "metadata" in payload:
del payload[key] del payload["metadata"]
model_id = form_data.model model_id = form_data.model
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)

View File

@ -273,10 +273,12 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return params return params
async def generate_function_chat_completion(form_data, user, files, tool_ids): async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None) metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
__event_emitter__ = None __event_emitter__ = None
__event_call__ = None __event_call__ = None

View File

@ -326,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
print(f"Error: {e}") print(f"Error: {e}")
raise e raise e
if skip_files and "files" in body: if skip_files and "files" in body.get("metadata", {}):
del body["files"] del body["metadata"]["files"]
return body, {} return body, {}
@ -371,7 +371,8 @@ async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
tool_ids = body.get("tool_ids", None) metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
if not tool_ids: if not tool_ids:
return body, {} return body, {}
@ -387,7 +388,7 @@ async def chat_completion_tools_handler(
**extra_params, **extra_params,
"__model__": app.state.MODELS[task_model_id], "__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"], "__messages__": body["messages"],
"__files__": body.get("files", []), "__files__": metadata.get("files", []),
} }
tools = get_tools(webui_app, tool_ids, user, custom_params) tools = get_tools(webui_app, tool_ids, user, custom_params)
log.info(f"{tools=}") log.info(f"{tools=}")
@ -454,8 +455,8 @@ async def chat_completion_tools_handler(
log.debug(f"tool_contexts: {contexts}") log.debug(f"tool_contexts: {contexts}")
if skip_files and "files" in body: if skip_files and "files" in body.get("metadata", {}):
del body["files"] del body["metadata"]["files"]
return body, {"contexts": contexts, "citations": citations} return body, {"contexts": contexts, "citations": citations}
@ -464,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = [] contexts = []
citations = [] citations = []
if files := body.get("files", None): if files := body.get("metadata", {}).get("files", None):
contexts, citations = get_rag_context( contexts, citations = get_rag_context(
files=files, files=files,
messages=body["messages"], messages=body["messages"],
@ -986,11 +987,8 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
files = form_data.pop("files", [])
tool_ids = form_data.pop("tool_ids", [])
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user, files, tool_ids) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user) return await generate_ollama_chat_completion(form_data, user=user)
else: else:

View File

@ -844,8 +844,10 @@
}, },
format: $settings.requestFormat ?? undefined, format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, metadata: {
files: files.length > 0 ? files : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined
},
session_id: $socket?.id, session_id: $socket?.id,
chat_id: $chatId, chat_id: $chatId,
id: responseMessageId id: responseMessageId
@ -1136,8 +1138,10 @@
frequency_penalty: frequency_penalty:
params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined, params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined,
max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined, max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, metadata: {
files: files.length > 0 ? files : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined
},
session_id: $socket?.id, session_id: $socket?.id,
chat_id: $chatId, chat_id: $chatId,
id: responseMessageId id: responseMessageId