This commit is contained in:
Timothy J. Baek 2024-06-18 16:08:42 -07:00
parent bcc27e3852
commit 514c7f1520
3 changed files with 30 additions and 26 deletions

View File

@ -170,7 +170,9 @@ app.state.MODELS = {}
origins = ["*"]
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
function_result = function(**params)
except Exception as e:
print(e)
@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
response = await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
# If files field is present, generate RAG completions
if "files" in data:
data = {**data}
rag_context, citations = get_rag_context(
@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
template,
model_id,
user,
)
return context
except Exception as e:

View File

@ -587,22 +587,17 @@
});
let files = [];
if (model?.info?.meta?.knowledge ?? false) {
files = model.info.meta.knowledge;
}
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
files = [
...files,
...messages
.filter((message) => message?.files ?? null)
.map((message) =>
message.files.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
)
)
.flat(1)
...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? [])
].filter(
// Remove duplicates
(item, index, array) =>
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
);
@ -832,22 +827,17 @@
const responseMessage = history.messages[responseMessageId];
let files = [];
if (model?.info?.meta?.knowledge ?? false) {
files = model.info.meta.knowledge;
}
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
files = [
...files,
...messages
.filter((message) => message?.files ?? null)
.map((message) =>
message.files.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
)
)
.flat(1)
...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? [])
].filter(
// Remove duplicates
(item, index, array) =>
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
);

View File

@ -153,6 +153,7 @@
if (res) {
fileItem.status = 'processed';
fileItem.collection_name = res.collection_name;
files = files;
}
} catch (e) {