From bd5a8567ef7c39fba67d3ae3f55b35412fb35255 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 11 Jun 2024 01:10:24 -0700 Subject: [PATCH] refac: tools & rag --- backend/apps/rag/utils.py | 14 ++---------- backend/main.py | 48 ++++++++++++++++++++------------------- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7d92dd10f..d0570f748 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -236,10 +236,9 @@ def get_embedding_function( return lambda query: generate_multiple(query, func) -def rag_messages( +def get_rag_context( docs, messages, - template, embedding_function, k, reranking_function, @@ -318,16 +317,7 @@ def rag_messages( context_string = context_string.strip() - ra_content = rag_template( - template=template, - context=context_string, - query=query, - ) - - log.debug(f"ra_content: {ra_content}") - messages = add_or_update_system_message(ra_content, messages) - - return messages, citations + return context_string, citations def get_model_path(model: str, update_model: bool = False): diff --git a/backend/main.py b/backend/main.py index 4376da288..1e4a416dd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -64,7 +64,7 @@ from utils.task import ( ) from utils.misc import get_last_user_message, add_or_update_system_message -from apps.rag.utils import rag_messages, rag_template +from apps.rag.utils import get_rag_context, rag_template from config import ( CONFIG_DATA, @@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + # Remove the citations from the body return_citations = data.get("citations", False) if "citations" in data: @@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if "tool_ids" in data: - user = get_current_user( - get_http_authorization_cred(request.headers.get("Authorization")) - ) - prompt = get_last_user_message(data["messages"]) - context = "" + context = "" + # If tool_ids field is present, call the functions + if "tool_ids" in data: + prompt = get_last_user_message(data["messages"]) for tool_id in data["tool_ids"]: print(tool_id) response = await get_function_call_response( @@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if response: context += ("\n" if context != "" else "") + response - - if context != "": - system_prompt = rag_template( - rag_app.state.config.RAG_TEMPLATE, context, prompt - ) - - print(system_prompt) - - data["messages"] = add_or_update_system_message( - f"\n{system_prompt}", data["messages"] - ) - del data["tool_ids"] # If docs field is present, generate RAG completions if "docs" in data: data = {**data} - data["messages"], citations = rag_messages( + rag_context, citations = get_rag_context( docs=data["docs"], messages=data["messages"], - template=rag_app.state.config.RAG_TEMPLATE, embedding_function=rag_app.state.EMBEDDING_FUNCTION, k=rag_app.state.config.TOP_K, reranking_function=rag_app.state.sentence_transformer_rf, r=rag_app.state.config.RELEVANCE_THRESHOLD, hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) + + if rag_context: + context += ("\n" if context != "" else "") + rag_context + del data["docs"] - log.debug( - f"data['messages']: {data['messages']}, citations: {citations}" + log.debug(f"rag_context: {rag_context}, citations: {citations}") + + if context != "": + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) + + print(system_prompt) + + data["messages"] = add_or_update_system_message( + f"\n{system_prompt}", data["messages"] ) modified_body_bytes = json.dumps(data).encode("utf-8")