Timothy J. Baek 2024-06-18 16:08:42 -07:00
parent bcc27e3852
commit 514c7f1520
3 changed files with 30 additions and 26 deletions

View File

@@ -170,7 +170,9 @@ app.state.MODELS = {}
 origins = ["*"]
-async def get_function_call_response(messages, tool_id, template, task_model_id, user):
+async def get_function_call_response(
+    messages, files, tool_id, template, task_model_id, user
+):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tools_function_calling_generation_template(template, tools_specs)
@@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
             "__messages__": messages,
         }
+        if "__files__" in sig.parameters:
+            # Call the function with the '__files__' parameter included
+            params = {
+                **params,
+                "__files__": files,
+            }
         function_result = function(**params)
     except Exception as e:
         print(e)
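
For orientation, here is a minimal, self-contained Python sketch of the signature-based injection the hunk above relies on: __files__ is only passed when the tool function declares it, so tools that do not know about files keep working unchanged. The call_tool helper and the summarize_files tool are illustrative only, not part of the commit.

import inspect

def call_tool(function, params: dict, files: list):
    # Mirror of the pattern above: inject '__files__' only when the
    # tool's signature declares it, leaving older tools unaffected.
    sig = inspect.signature(function)
    if "__files__" in sig.parameters:
        params = {**params, "__files__": files}
    return function(**params)

# Hypothetical tool function that opts into the injected files.
def summarize_files(query: str, __files__: list = []):
    return f"{len(__files__)} file(s) attached for query '{query}'"

print(call_tool(summarize_files, {"query": "report"}, files=[{"type": "file", "name": "a.pdf"}]))
# -> 1 file(s) attached for query 'report'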
@@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         try:
             response = await get_function_call_response(
                 messages=data["messages"],
+                files=data.get("files", []),
                 tool_id=tool_id,
                 template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                 task_model_id=task_model_id,
@@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         print(f"tool_context: {context}")
-    # If docs field is present, generate RAG completions
+    # TODO: Check if tools & functions have files support to skip this step to delegate file processing
+    # If files field is present, generate RAG completions
     if "files" in data:
         data = {**data}
         rag_context, citations = get_rag_context(
@@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             system_prompt = rag_template(
                 rag_app.state.config.RAG_TEMPLATE, context, prompt
             )
             print(system_prompt)
             data["messages"] = add_or_update_system_message(
                 f"\n{system_prompt}", data["messages"]
             )
         modified_body_bytes = json.dumps(data).encode("utf-8")
         # Replace the request body with the modified one
         request._body = modified_body_bytes
         # Set custom header to ensure content-length matches new body length
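
The hunk above folds the retrieved RAG context into the conversation via add_or_update_system_message before re-encoding the request body. A hedged sketch of what that helper plausibly does, assuming it behaves as its name suggests; the real implementation lives in the project's utils and may differ in detail:

def add_or_update_system_message(content: str, messages: list) -> list:
    # If the conversation already starts with a system message, fold the
    # new content into it; otherwise insert a fresh system message up front.
    messages = list(messages)
    if messages and messages[0].get("role") == "system":
        messages[0] = {
            **messages[0],
            "content": f"{messages[0]['content']}\n{content}",
        }
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages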
@@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     try:
         context = await get_function_call_response(
-            form_data["messages"], form_data["tool_id"], template, model_id, user
+            form_data["messages"],
+            form_data.get("files", []),
+            form_data["tool_id"],
+            template,
+            model_id,
+            user,
         )
         return context
     except Exception as e:
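
Taken together, both call sites now thread the request's files into get_function_call_response, falling back to an empty list so clients that send no files are unaffected. As a rough illustration of the consumer side, assuming the toolkit convention of a Tools class whose methods receive reserved double-underscore parameters (the class and method names here are hypothetical, not from this commit):

# Illustrative tool: the method opts into the newly injected '__files__'
# parameter alongside the existing '__messages__' one.
class Tools:
    def list_attachments(self, __files__: list = [], __messages__: list = []) -> str:
        # Return the names of files attached to the current request.
        names = [f.get("name", "unknown") for f in __files__]
        return ", ".join(names) if names else "no attachments"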

View File

@@ -587,22 +587,17 @@
 });
 
 let files = [];
 if (model?.info?.meta?.knowledge ?? false) {
     files = model.info.meta.knowledge;
 }
 
+const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
 files = [
     ...files,
-    ...messages
-        .filter((message) => message?.files ?? null)
-        .map((message) =>
-            message.files.filter((item) =>
-                ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-            )
-        )
-        .flat(1)
+    ...(lastUserMessage?.files?.filter((item) =>
+        ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
+    ) ?? [])
 ].filter(
-    // Remove duplicates
     (item, index, array) =>
         array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
 );
@@ -832,22 +827,17 @@
 const responseMessage = history.messages[responseMessageId];
 
 let files = [];
 if (model?.info?.meta?.knowledge ?? false) {
     files = model.info.meta.knowledge;
 }
 
+const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
 files = [
     ...files,
-    ...messages
-        .filter((message) => message?.files ?? null)
-        .map((message) =>
-            message.files.filter((item) =>
-                ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-            )
-        )
-        .flat(1)
+    ...(lastUserMessage?.files?.filter((item) =>
+        ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
+    ) ?? [])
 ].filter(
-    // Remove duplicates
     (item, index, array) =>
         array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
 );

View File

@@ -153,6 +153,7 @@
     if (res) {
         fileItem.status = 'processed';
+        fileItem.collection_name = res.collection_name;
         files = files;
     }
 } catch (e) {