diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 91b07e0aa..b2da7d90c 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str): template = re.sub(r"\[query\]", query, template) return template + + +def rag_messages(docs, messages, template, k, embedding_function): + print(docs) + + last_user_message_idx = None + for i in range(len(messages) - 1, -1, -1): + if messages[i]["role"] == "user": + last_user_message_idx = i + break + + user_message = messages[last_user_message_idx] + + if isinstance(user_message["content"], list): + # Handle list content input + content_type = "list" + query = "" + for content_item in user_message["content"]: + if content_item["type"] == "text": + query = content_item["text"] + break + elif isinstance(user_message["content"], str): + # Handle text content input + content_type = "text" + query = user_message["content"] + else: + # Fallback in case the input does not match expected types + content_type = None + query = "" + + relevant_contexts = [] + + for doc in docs: + context = None + + try: + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=k, + embedding_function=embedding_function, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=k, + embedding_function=embedding_function, + ) + except Exception as e: + print(e) + context = None + + relevant_contexts.append(context) + + context_string = "" + for context in relevant_contexts: + if context: + context_string += " ".join(context["documents"][0]) + "\n" + + ra_content = rag_template( + template=template, + context=context_string, + query=query, + ) + + if content_type == "list": + new_content = [] + for content_item in user_message["content"]: + if content_item["type"] == "text": + # Update the text item's content with ra_content + new_content.append({"type": "text", "text": ra_content}) + else: + # Keep other types of content as they are + new_content.append(content_item) + new_user_message = {**user_message, "content": new_content} + else: + new_user_message = { + **user_message, + "content": ra_content, + } + + messages[last_user_message_idx] = new_user_message + + return messages diff --git a/backend/main.py b/backend/main.py index c7523ec62..253227182 100644 --- a/backend/main.py +++ b/backend/main.py @@ -28,7 +28,7 @@ from typing import List from utils.utils import get_admin_user -from apps.rag.utils import query_doc, query_collection, rag_template +from apps.rag.utils import rag_messages from config import ( WEBUI_NAME, @@ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST origins = ["*"] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.on_event("startup") -async def on_startup(): - await litellm_app_startup() - class RAGMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): @@ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware): # Example: Add a new key-value pair or modify existing ones # data["modified"] = True # Example modification if "docs" in data: - docs = data["docs"] - print(docs) - last_user_message_idx = None - for i in range(len(data["messages"]) - 1, -1, -1): - if data["messages"][i]["role"] == "user": - last_user_message_idx = i - break - - user_message = data["messages"][last_user_message_idx] - - if isinstance(user_message["content"], list): - # Handle list content input - content_type = "list" - query = "" - for content_item in user_message["content"]: - if content_item["type"] == "text": - query = content_item["text"] - break - elif isinstance(user_message["content"], str): - # Handle text content input - content_type = "text" - query = user_message["content"] - else: - # Fallback in case the input does not match expected types - content_type = None - query = "" - - relevant_contexts = [] - - for doc in docs: - context = None - - try: - if doc["type"] == "collection": - context = query_collection( - collection_names=doc["collection_names"], - query=query, - k=rag_app.state.TOP_K, - embedding_function=rag_app.state.sentence_transformer_ef, - ) - else: - context = query_doc( - collection_name=doc["collection_name"], - query=query, - k=rag_app.state.TOP_K, - embedding_function=rag_app.state.sentence_transformer_ef, - ) - except Exception as e: - print(e) - context = None - - relevant_contexts.append(context) - - context_string = "" - for context in relevant_contexts: - if context: - context_string += " ".join(context["documents"][0]) + "\n" - - ra_content = rag_template( - template=rag_app.state.RAG_TEMPLATE, - context=context_string, - query=query, + data = {**data} + data["messages"] = rag_messages( + data["docs"], + data["messages"], + rag_app.state.RAG_TEMPLATE, + rag_app.state.TOP_K, + rag_app.state.sentence_transformer_ef, ) - - if content_type == "list": - new_content = [] - for content_item in user_message["content"]: - if content_item["type"] == "text": - # Update the text item's content with ra_content - new_content.append({"type": "text", "text": ra_content}) - else: - # Keep other types of content as they are - new_content.append(content_item) - new_user_message = {**user_message, "content": new_content} - else: - new_user_message = { - **user_message, - "content": ra_content, - } - - data["messages"][last_user_message_idx] = new_user_message del data["docs"] print(data["messages"]) modified_body_bytes = json.dumps(data).encode("utf-8") - # Create a new request with the modified body - scope = request.scope - scope["body"] = modified_body_bytes - request = Request(scope, receive=lambda: self._receive(modified_body_bytes)) + # Replace the request body with the modified one + request._body = modified_body_bytes + + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[ + (k, v) + for k, v in request.headers.raw + if k.lower() != b"content-length" + ], + ] response = await call_next(request) return response @@ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware): app.add_middleware(RAGMiddleware) +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) @@ -204,6 +135,11 @@ async def check_url(request: Request, call_next): return response +@app.on_event("startup") +async def on_startup(): + await litellm_app_startup() + + app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app)