diff --git a/backend/main.py b/backend/main.py index 47078b681..346902de6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -247,6 +247,7 @@ async def get_function_call_response( result = json.loads(content) print(result) + citation = None # Call the function if "name" in result: if tool_id in webui_app.state.TOOLS: @@ -309,22 +310,32 @@ async def get_function_call_response( } function_result = function(**params) + + if hasattr(toolkit_module, "citation") and toolkit_module.citation: + citation = { + "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, + "document": [function_result], + "metadata": [{"source": result["name"]}], + } except Exception as e: print(e) # Add the function result to the system prompt if function_result is not None: - return function_result, file_handler + return function_result, citation, file_handler except Exception as e: print(f"Error: {e}") - return None, False + return None, None, False class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): data_items = [] + show_citations = False + citations = [] + if request.method == "POST" and any( endpoint in request.url.path for endpoint in ["/ollama/api/chat", "/chat/completions"] @@ -342,6 +353,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) # Flag to skip RAG completions if file_handler is present in tools/functions skip_files = False + if data.get("citations"): + show_citations = True + del data["citations"] model_id = data["model"] if model_id not in app.state.MODELS: @@ -365,8 +379,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): webui_app.state.FUNCTIONS[filter_id] = function_module # Check if the function has a file_handler variable - if getattr(function_module, "file_handler"): - skip_files = True + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler try: if hasattr(function_module, "inlet"): @@ -411,19 +425,25 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): for tool_id in data["tool_ids"]: print(tool_id) try: - response, file_handler = await get_function_call_response( - messages=data["messages"], - files=data.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, + response, citation, file_handler = ( + await get_function_call_response( + messages=data["messages"], + files=data.get("files", []), + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) ) print(file_handler) if isinstance(response, str): context += ("\n" if context != "" else "") + response + if citation: + citations.append(citation) + show_citations = True + if file_handler: skip_files = True @@ -438,7 +458,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if "files" in data: if not skip_files: data = {**data} - rag_context, citations = get_rag_context( + rag_context, rag_citations = get_rag_context( files=data["files"], messages=data["messages"], embedding_function=rag_app.state.EMBEDDING_FUNCTION, @@ -452,13 +472,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): log.debug(f"rag_context: {rag_context}, citations: {citations}") - if citations and data.get("citations"): - data_items.append({"citations": citations}) + if rag_citations: + citations.extend(rag_citations) del data["files"] - if data.get("citations"): - del data["citations"] + if show_citations and len(citations) > 0: + data_items.append({"citations": citations}) if context != "": system_prompt = rag_template( @@ -1285,7 +1305,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE try: - context, file_handler = await get_function_call_response( + context, citation, file_handler = await get_function_call_response( form_data["messages"], form_data.get("files", []), form_data["tool_id"],