From 514c7f1520783cffc8477e95048f31f5a803e31c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 18 Jun 2024 16:08:42 -0700 Subject: [PATCH] fix: rag --- backend/main.py | 25 ++++++++++++----- src/lib/components/chat/Chat.svelte | 30 +++++++-------------- src/lib/components/chat/MessageInput.svelte | 1 + 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/backend/main.py b/backend/main.py index 5950dabee..1d240249f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -170,7 +170,9 @@ app.state.MODELS = {} origins = ["*"] -async def get_function_call_response(messages, tool_id, template, task_model_id, user): +async def get_function_call_response( + messages, files, tool_id, template, task_model_id, user +): tool = Tools.get_tool_by_id(tool_id) tools_specs = json.dumps(tool.specs, indent=2) content = tools_function_calling_generation_template(template, tools_specs) @@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, "__messages__": messages, } + if "__files__" in sig.parameters: + # Call the function with the '__files__' parameter included + params = { + **params, + "__files__": files, + } + function_result = function(**params) except Exception as e: print(e) @@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: response = await get_function_call_response( messages=data["messages"], + files=data.get("files", []), tool_id=tool_id, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, task_model_id=task_model_id, @@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(f"tool_context: {context}") - # If docs field is present, generate RAG completions + # TODO: Check if tools & functions have files support to skip this step to delegate file processing + # If files field is present, generate RAG completions if "files" in data: data = {**data} rag_context, citations = get_rag_context( @@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): system_prompt = rag_template( rag_app.state.config.RAG_TEMPLATE, context, prompt ) - print(system_prompt) - data["messages"] = add_or_update_system_message( f"\n{system_prompt}", data["messages"] ) modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one request._body = modified_body_bytes # Set custom header to ensure content-length matches new body length @@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ try: context = await get_function_call_response( - form_data["messages"], form_data["tool_id"], template, model_id, user + form_data["messages"], + form_data.get("files", []), + form_data["tool_id"], + template, + model_id, + user, ) return context except Exception as e: diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 4fbd50993..c52a726cd 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -587,22 +587,17 @@ }); let files = []; - if (model?.info?.meta?.knowledge ?? false) { files = model.info.meta.knowledge; } - + const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); files = [ ...files, - ...messages - .filter((message) => message?.files ?? null) - .map((message) => - message.files.filter((item) => - ['doc', 'file', 'collection', 'web_search_results'].includes(item.type) - ) - ) - .flat(1) + ...(lastUserMessage?.files?.filter((item) => + ['doc', 'file', 'collection', 'web_search_results'].includes(item.type) + ) ?? []) ].filter( + // Remove duplicates (item, index, array) => array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index ); @@ -832,22 +827,17 @@ const responseMessage = history.messages[responseMessageId]; let files = []; - if (model?.info?.meta?.knowledge ?? false) { files = model.info.meta.knowledge; } - + const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); files = [ ...files, - ...messages - .filter((message) => message?.files ?? null) - .map((message) => - message.files.filter((item) => - ['doc', 'file', 'collection', 'web_search_results'].includes(item.type) - ) - ) - .flat(1) + ...(lastUserMessage?.files?.filter((item) => + ['doc', 'file', 'collection', 'web_search_results'].includes(item.type) + ) ?? []) ].filter( + // Remove duplicates (item, index, array) => array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index ); diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index f199b9154..eca040d83 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -153,6 +153,7 @@ if (res) { fileItem.status = 'processed'; + fileItem.collection_name = res.collection_name; files = files; } } catch (e) {