From c49491e516aaa6023a21ae320c676048a32743cd Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 22:34:47 -0800 Subject: [PATCH 1/6] refac: rag to backend --- backend/apps/rag/utils.py | 8 ++++ backend/main.py | 86 +++++++++++++++++++++++++++++++++++ src/lib/apis/rag/index.ts | 2 +- src/routes/(app)/+page.svelte | 67 ++++++++------------------- 4 files changed, 113 insertions(+), 50 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 193743971..91b07e0aa 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -1,3 +1,4 @@ +import re from typing import List from config import CHROMA_CLIENT @@ -87,3 +88,10 @@ def query_collection( pass return merge_and_sort_query_results(results, k) + + +def rag_template(template: str, context: str, query: str): + template = re.sub(r"\[context\]", context, template) + template = re.sub(r"\[query\]", query, template) + + return template diff --git a/backend/main.py b/backend/main.py index afa974ca6..cc5edc0f9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,6 +12,7 @@ from fastapi import HTTPException from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app @@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +from apps.rag.utils import query_doc, query_collection, rag_template + from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR from constants import ERROR_MESSAGES @@ -56,6 +59,89 @@ async def on_startup(): await litellm_app_startup() +class RAGMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + + print(request.url.path) + if request.method == "POST": + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + # Example: Add a new key-value pair or modify existing ones + # data["modified"] = True # Example modification + if "docs" in data: + docs = data["docs"] + print(docs) + + last_user_message_idx = None + for i in range(len(data["messages"]) - 1, -1, -1): + if data["messages"][i]["role"] == "user": + last_user_message_idx = i + break + + query = data["messages"][last_user_message_idx]["content"] + + relevant_contexts = [] + + for doc in docs: + context = None + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + relevant_contexts.append(context) + + context_string = "" + for context in relevant_contexts: + if context: + context_string += " ".join(context["documents"][0]) + "\n" + + content = rag_template( + template=rag_app.state.RAG_TEMPLATE, + context=context_string, + query=query, + ) + + new_user_message = { + **data["messages"][last_user_message_idx], + "content": content, + } + data["messages"][last_user_message_idx] = new_user_message + del data["docs"] + + print("DATAAAAAAAAAAAAAAAAAA") + print(data) + modified_body_bytes = json.dumps(data).encode("utf-8") + + # 
Create a new request with the modified body + scope = request.scope + scope["body"] = modified_body_bytes + request = Request(scope, receive=lambda: self._receive(modified_body_bytes)) + + response = await call_next(request) + return response + + async def _receive(self, body: bytes): + return {"type": "http.request", "body": body, "more_body": False} + + +app.add_middleware(RAGMiddleware) + + @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 4e8e9b14c..6dcfbbe7d 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -252,7 +252,7 @@ export const queryCollection = async ( token: string, collection_names: string, query: string, - k: number + k: number | null = null ) => { let error = null; diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 5b968ac90..e5510a066 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -232,53 +232,6 @@ const sendPrompt = async (prompt, parentId) => { const _chatId = JSON.parse(JSON.stringify($chatId)); - const docs = messages - .filter((message) => message?.files ?? null) - .map((message) => - message.files.filter((item) => item.type === 'doc' || item.type === 'collection') - ) - .flat(1); - - console.log(docs); - if (docs.length > 0) { - processing = 'Reading'; - const query = history.messages[parentId].content; - - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - if (doc.type === 'collection') { - return await queryCollection(localStorage.token, doc.collection_names, query).catch( - (error) => { - console.log(error); - return null; - } - ); - } else { - return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => { - console.log(error); - return null; - }); - } - }) - ); - relevantContexts = relevantContexts.filter((context) => context); - - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); - - console.log(contextString); - - history.messages[parentId].raContent = await RAGTemplate( - localStorage.token, - contextString, - query - ); - history.messages[parentId].contexts = relevantContexts; - await tick(); - processing = ''; - } - await Promise.all( selectedModels.map(async (modelId) => { const model = $models.filter((m) => m.id === modelId).at(0); @@ -368,6 +321,13 @@ } }); + const docs = messages + .filter((message) => message?.files ?? null) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) + .flat(1); + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model, messages: messagesBody, @@ -375,7 +335,8 @@ ...($settings.options ?? {}) }, format: $settings.requestFormat ?? undefined, - keep_alive: $settings.keepAlive ?? undefined + keep_alive: $settings.keepAlive ?? undefined, + docs: docs }); if (res && res.ok) { @@ -535,6 +496,13 @@ const responseMessage = history.messages[responseMessageId]; scrollToBottom(); + const docs = messages + .filter((message) => message?.files ?? null) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) + .flat(1); + const res = await generateOpenAIChatCompletion( localStorage.token, { @@ -583,7 +551,8 @@ top_p: $settings?.options?.top_p ?? undefined, num_ctx: $settings?.options?.num_ctx ?? undefined, frequency_penalty: $settings?.options?.repeat_penalty ?? 
undefined, - max_tokens: $settings?.options?.num_predict ?? undefined + max_tokens: $settings?.options?.num_predict ?? undefined, + docs: docs }, model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}` ); From 6c58bb59bed875752fe3cb90edc499da7bb72957 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 22:43:06 -0800 Subject: [PATCH 2/6] feat: rag docs as payload field --- backend/main.py | 2 - src/routes/(app)/+page.svelte | 6 ++- src/routes/(app)/c/[id]/+page.svelte | 70 +++++++++------------------- 3 files changed, 25 insertions(+), 53 deletions(-) diff --git a/backend/main.py b/backend/main.py index cc5edc0f9..d36c8420e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -123,8 +123,6 @@ class RAGMiddleware(BaseHTTPMiddleware): data["messages"][last_user_message_idx] = new_user_message del data["docs"] - print("DATAAAAAAAAAAAAAAAAAA") - print(data) modified_body_bytes = json.dumps(data).encode("utf-8") # Create a new request with the modified body diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index e5510a066..28bd8eb6f 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -336,7 +336,7 @@ }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, - docs: docs + docs: docs.length > 0 ? docs : undefined }); if (res && res.ok) { @@ -503,6 +503,8 @@ ) .flat(1); + console.log(docs); + const res = await generateOpenAIChatCompletion( localStorage.token, { @@ -552,7 +554,7 @@ num_ctx: $settings?.options?.num_ctx ?? undefined, frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, max_tokens: $settings?.options?.num_predict ?? undefined, - docs: docs + docs: docs.length > 0 ? docs : undefined }, model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}` ); diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index dc9f8a580..0ec3fae40 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -245,53 +245,6 @@ const sendPrompt = async (prompt, parentId) => { const _chatId = JSON.parse(JSON.stringify($chatId)); - const docs = messages - .filter((message) => message?.files ?? 
null) - .map((message) => - message.files.filter((item) => item.type === 'doc' || item.type === 'collection') - ) - .flat(1); - - console.log(docs); - if (docs.length > 0) { - processing = 'Reading'; - const query = history.messages[parentId].content; - - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - if (doc.type === 'collection') { - return await queryCollection(localStorage.token, doc.collection_names, query).catch( - (error) => { - console.log(error); - return null; - } - ); - } else { - return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => { - console.log(error); - return null; - }); - } - }) - ); - relevantContexts = relevantContexts.filter((context) => context); - - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); - - console.log(contextString); - - history.messages[parentId].raContent = await RAGTemplate( - localStorage.token, - contextString, - query - ); - history.messages[parentId].contexts = relevantContexts; - await tick(); - processing = ''; - } - await Promise.all( selectedModels.map(async (modelId) => { const model = $models.filter((m) => m.id === modelId).at(0); @@ -381,6 +334,13 @@ } }); + const docs = messages + .filter((message) => message?.files ?? null) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) + .flat(1); + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model, messages: messagesBody, @@ -388,7 +348,8 @@ ...($settings.options ?? {}) }, format: $settings.requestFormat ?? undefined, - keep_alive: $settings.keepAlive ?? undefined + keep_alive: $settings.keepAlive ?? undefined, + docs: docs.length > 0 ? docs : undefined }); if (res && res.ok) { @@ -548,6 +509,15 @@ const responseMessage = history.messages[responseMessageId]; scrollToBottom(); + const docs = messages + .filter((message) => message?.files ?? null) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) + .flat(1); + + console.log(docs); + const res = await generateOpenAIChatCompletion( localStorage.token, { @@ -596,7 +566,8 @@ top_p: $settings?.options?.top_p ?? undefined, num_ctx: $settings?.options?.num_ctx ?? undefined, frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, - max_tokens: $settings?.options?.num_predict ?? undefined + max_tokens: $settings?.options?.num_predict ?? undefined, + docs: docs.length > 0 ? docs : undefined }, model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}` ); @@ -710,6 +681,7 @@ await setChatTitle(_chatId, userPrompt); } }; + const stopResponse = () => { stopResponseFlag = true; console.log('stopResponse'); From dfcc31428337670c4858a1729d64fbfe9e34a1dd Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Fri, 8 Mar 2024 22:51:42 -0800 Subject: [PATCH 3/6] fix: only edit body with whitelisted paths --- backend/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index d36c8420e..330c9b293 100644 --- a/backend/main.py +++ b/backend/main.py @@ -62,8 +62,12 @@ async def on_startup(): class RAGMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): + "chat/completions" in request.url.path + print(request.url.path) - if request.method == "POST": + if request.method == "POST" and ( + "/api/chat" in request.url.path or "/chat/completions" in request.url.path + ): # Read the original request body body = await request.body() # Decode body to string From 9f58ed5afac25d056112c9cc7dc71a9da4df859c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 22:52:42 -0800 Subject: [PATCH 4/6] fix --- backend/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index 330c9b293..bb424ae0b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -61,13 +61,11 @@ async def on_startup(): class RAGMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - - "chat/completions" in request.url.path - - print(request.url.path) if request.method == "POST" and ( "/api/chat" in request.url.path or "/chat/completions" in request.url.path ): + print(request.url.path) + # Read the original request body body = await request.body() # Decode body to string From d936353da0f6131d0cf4157f02855902d78cb159 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 23:19:20 -0800 Subject: [PATCH 5/6] fix: message type edge case --- backend/main.py | 41 ++++++++++++++++++++++++---- src/routes/(app)/+page.svelte | 28 +++++++++++++------ src/routes/(app)/c/[id]/+page.svelte | 28 +++++++++++++------ 3 files changed, 73 insertions(+), 24 deletions(-) diff --git a/backend/main.py b/backend/main.py index bb424ae0b..11ca81fc5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -85,7 +85,24 @@ class RAGMiddleware(BaseHTTPMiddleware): last_user_message_idx = i break - query = data["messages"][last_user_message_idx]["content"] + user_message = data["messages"][last_user_message_idx] + + if isinstance(user_message["content"], list): + # Handle list content input + content_type = "list" + query = "" + for content_item in user_message["content"]: + if content_item["type"] == "text": + query = content_item["text"] + break + elif isinstance(user_message["content"], str): + # Handle text content input + content_type = "text" + query = user_message["content"] + else: + # Fallback in case the input does not match expected types + content_type = None + query = "" relevant_contexts = [] @@ -112,16 +129,28 @@ class RAGMiddleware(BaseHTTPMiddleware): if context: context_string += " ".join(context["documents"][0]) + "\n" - content = rag_template( + ra_content = rag_template( template=rag_app.state.RAG_TEMPLATE, context=context_string, query=query, ) - new_user_message = { - **data["messages"][last_user_message_idx], - "content": content, - } + if content_type == "list": + new_content = [] + for content_item in user_message["content"]: + if content_item["type"] == "text": + # Update the text item's content with ra_content + new_content.append({"type": "text", "text": ra_content}) + else: + # Keep other types of content as they are + new_content.append(content_item) + new_user_message = {**user_message, "content": new_content} + else: + 
new_user_message = { + **user_message, + "content": ra_content, + } + data["messages"][last_user_message_idx] = new_user_message del data["docs"] diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 28bd8eb6f..bb3668dcc 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -295,15 +295,25 @@ ...messages ] .filter((message) => message) - .map((message, idx, arr) => ({ - role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, - ...(message.files && { - images: message.files - .filter((file) => file.type === 'image') - .map((file) => file.url.slice(file.url.indexOf(',') + 1)) - }) - })); + .map((message, idx, arr) => { + // Prepare the base message object + const baseMessage = { + role: message.role, + content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + }; + + // Extract and format image URLs if any exist + const imageUrls = message.files + ?.filter((file) => file.type === 'image') + .map((file) => file.url.slice(file.url.indexOf(',') + 1)); + + // Add images array only if it contains elements + if (imageUrls && imageUrls.length > 0) { + baseMessage.images = imageUrls; + } + + return baseMessage; + }); let lastImageIndex = -1; diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 0ec3fae40..4bc6acfa2 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -308,15 +308,25 @@ ...messages ] .filter((message) => message) - .map((message, idx, arr) => ({ - role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, - ...(message.files && { - images: message.files - .filter((file) => file.type === 'image') - .map((file) => file.url.slice(file.url.indexOf(',') + 1)) - }) - })); + .map((message, idx, arr) => { + // Prepare the base message object + const baseMessage = { + role: message.role, + content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + }; + + // Extract and format image URLs if any exist + const imageUrls = message.files + ?.filter((file) => file.type === 'image') + .map((file) => file.url.slice(file.url.indexOf(',') + 1)); + + // Add images array only if it contains elements + if (imageUrls && imageUrls.length > 0) { + baseMessage.images = imageUrls; + } + + return baseMessage; + }); let lastImageIndex = -1; From 784ee6f52183ab2158359d2bf75535ca7ad740d2 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Fri, 8 Mar 2024 23:21:00 -0800 Subject: [PATCH 6/6] fix: error handling --- backend/main.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/backend/main.py b/backend/main.py index 11ca81fc5..e63f91a04 100644 --- a/backend/main.py +++ b/backend/main.py @@ -108,20 +108,26 @@ class RAGMiddleware(BaseHTTPMiddleware): for doc in docs: context = None - if doc["type"] == "collection": - context = query_collection( - collection_names=doc["collection_names"], - query=query, - k=rag_app.state.TOP_K, - embedding_function=rag_app.state.sentence_transformer_ef, - ) - else: - context = query_doc( - collection_name=doc["collection_name"], - query=query, - k=rag_app.state.TOP_K, - embedding_function=rag_app.state.sentence_transformer_ef, - ) + + try: + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + except Exception as e: + print(e) + context = None + relevant_contexts.append(context) context_string = ""