diff --git a/backend/main.py b/backend/main.py index 1d240249f..b14d56436 100644 --- a/backend/main.py +++ b/backend/main.py @@ -241,6 +241,12 @@ async def get_function_call_response( toolkit_module = load_toolkit_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = toolkit_module + file_handler = False + # check if toolkit_module has file_handler self variable + if hasattr(toolkit_module, "file_handler"): + file_handler = True + print("file_handler: ", file_handler) + function = getattr(toolkit_module, result["name"]) function_result = None try: @@ -279,12 +285,12 @@ async def get_function_call_response( print(e) # Add the function result to the system prompt - if function_result: - return function_result + if function_result is not None: + return function_result, file_handler except Exception as e: print(f"Error: {e}") - return None + return None, False class ChatCompletionMiddleware(BaseHTTPMiddleware): @@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): context = "" # If tool_ids field is present, call the functions + + skip_files = False if "tool_ids" in data: print(data["tool_ids"]) for tool_id in data["tool_ids"]: print(tool_id) try: - response = await get_function_call_response( + response, file_handler = await get_function_call_response( messages=data["messages"], files=data.get("files", []), tool_id=tool_id, @@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): user=user, ) + print(file_handler) if isinstance(response, str): context += ("\n" if context != "" else "") + response + if file_handler: + skip_files = True + except Exception as e: print(f"Error: {e}") del data["tool_ids"] print(f"tool_context: {context}") - # TODO: Check if tools & functions have files support to skip this step to delegate file processing # If files field is present, generate RAG completions + # If skip_files is True, skip the RAG completions if "files" in data: - data = {**data} - rag_context, citations = get_rag_context( - files=data["files"], - messages=data["messages"], - embedding_function=rag_app.state.EMBEDDING_FUNCTION, - k=rag_app.state.config.TOP_K, - reranking_function=rag_app.state.sentence_transformer_rf, - r=rag_app.state.config.RELEVANCE_THRESHOLD, - hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, - ) + if not skip_files: + data = {**data} + rag_context, citations = get_rag_context( + files=data["files"], + messages=data["messages"], + embedding_function=rag_app.state.EMBEDDING_FUNCTION, + k=rag_app.state.config.TOP_K, + reranking_function=rag_app.state.sentence_transformer_rf, + r=rag_app.state.config.RELEVANCE_THRESHOLD, + hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, + ) + if rag_context: + context += ("\n" if context != "" else "") + rag_context - if rag_context: - context += ("\n" if context != "" else "") + rag_context + log.debug(f"rag_context: {rag_context}, citations: {citations}") + else: + return_citations = False del data["files"] - log.debug(f"rag_context: {rag_context}, citations: {citations}") if context != "": system_prompt = rag_template( @@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE try: - context = await get_function_call_response( + context, file_handler = await get_function_call_response( form_data["messages"], form_data.get("files", []), form_data["tool_id"],