diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 193743971..91b07e0aa 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -1,3 +1,4 @@
+import re
 from typing import List
 
 from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
             pass
 
     return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template
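Review note on the new `rag_template` helper: `re.sub` treats its replacement argument as a regex replacement string, so retrieved context containing backslashes (LaTeX, Windows paths) or group references like `\1` can raise `re.error: bad escape` or be silently rewritten. Since `[context]` and `[query]` are literal tokens, plain `str.replace` sidesteps this; a minimal sketch of an escaping-safe variant:

```python
# Sketch of an escaping-safe alternative; behavior is otherwise
# identical to the rag_template added in this diff.
def rag_template_safe(template: str, context: str, query: str) -> str:
    # str.replace substitutes literally, so backslashes or regex group
    # references in retrieved documents pass through untouched.
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
```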
diff --git a/backend/main.py b/backend/main.py
index afa974ca6..e63f91a04 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -12,6 +12,7 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
 
 
 from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app
 
+from apps.rag.utils import query_doc, query_collection, rag_template
+
 
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from constants import ERROR_MESSAGES
 
@@ -56,6 +59,124 @@ async def on_startup():
     await litellm_app_startup()
 
 
+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and (
+            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+        ):
+            print(request.url.path)
+
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                user_message = data["messages"][last_user_message_idx]
+
+                if isinstance(user_message["content"], list):
+                    # Handle list content input
+                    content_type = "list"
+                    query = ""
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            query = content_item["text"]
+                            break
+                elif isinstance(user_message["content"], str):
+                    # Handle text content input
+                    content_type = "text"
+                    query = user_message["content"]
+                else:
+                    # Fallback in case the input does not match expected types
+                    content_type = None
+                    query = ""
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+
+                    try:
+                        if doc["type"] == "collection":
+                            context = query_collection(
+                                collection_names=doc["collection_names"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                        else:
+                            context = query_doc(
+                                collection_name=doc["collection_name"],
+                                query=query,
+                                k=rag_app.state.TOP_K,
+                                embedding_function=rag_app.state.sentence_transformer_ef,
+                            )
+                    except Exception as e:
+                        print(e)
+                        context = None
+
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                ra_content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                if content_type == "list":
+                    new_content = []
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            # Update the text item's content with ra_content
+                            new_content.append({"type": "text", "text": ra_content})
+                        else:
+                            # Keep other types of content as they are
+                            new_content.append(content_item)
+                    new_user_message = {**user_message, "content": new_content}
+                else:
+                    new_user_message = {
+                        **user_message,
+                        "content": ra_content,
+                    }
+
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())
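For reviewers unfamiliar with Starlette internals: `BaseHTTPMiddleware` offers no direct setter for the request body, so the middleware above swaps in a replacement `receive` callable; anything downstream that awaits `request.body()` then sees the modified payload. A stripped-down sketch of just that body-swap mechanism (the method check and the mutation here are illustrative, not the PR's logic):

```python
import json

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


class BodyRewriteMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            # Parse, mutate, and re-serialize the incoming JSON body.
            data = json.loads(await request.body() or b"{}")
            data["rewritten"] = True  # illustrative mutation

            body = json.dumps(data).encode("utf-8")

            async def receive():
                # Downstream request.body() calls replay this ASGI message.
                return {"type": "http.request", "body": body, "more_body": False}

            # Rebuild the Request around the patched receive callable.
            request = Request(request.scope, receive)

        return await call_next(request)
```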
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 4e8e9b14c..6dcfbbe7d 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -252,7 +252,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 
 	let error = null;
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 5b968ac90..bb3668dcc 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -342,15 +295,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -368,6 +331,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +345,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -535,6 +506,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +563,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
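The net effect on the client: instead of querying Chroma itself (the removed `queryDoc`/`queryCollection` block), the page now just forwards the attached doc/collection files and lets `RAGMiddleware` do the retrieval server-side. Based on the filter above and the fields the middleware reads, a `docs` entry in the request body looks roughly like this (values are made up):

```python
# Illustrative payload shape; only "type" plus "collection_name" /
# "collection_names" are actually consumed by RAGMiddleware.
docs = [
    {"type": "doc", "collection_name": "a1b2c3d4"},
    {"type": "collection", "collection_names": ["notes", "manuals"]},
]
# "collection" entries are routed to query_collection(), everything
# else to query_doc(), both with k=rag_app.state.TOP_K.
```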
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index dc9f8a580..4bc6acfa2 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -355,15 +308,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 
@@ -381,6 +344,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -388,7 +358,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs.length > 0 ? docs : undefined
 		});
 
 		if (res && res.ok) {
@@ -548,6 +519,15 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -596,7 +576,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
@@ -710,6 +691,7 @@
 			await setChatTitle(_chatId, userPrompt);
 		}
 	};
+
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		console.log('stopResponse');
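Putting the pieces together: the middleware pops `docs` off the payload, runs retrieval, and splices the joined context into the last user message via `rag_template`. A quick illustration of that final rewrite, with an assumed template string (the real default is whatever is configured on `rag_app.state.RAG_TEMPLATE`):

```python
import re


def rag_template(template: str, context: str, query: str):
    # Copied from the diff so the demo is self-contained.
    template = re.sub(r"\[context\]", context, template)
    template = re.sub(r"\[query\]", query, template)
    return template


# Assumed template for illustration only.
TEMPLATE = "Use the following context to answer the query:\n[context]\nQuery: [query]"

print(rag_template(TEMPLATE, "The doc says X.", "What does the doc say?"))
# Use the following context to answer the query:
# The doc says X.
# Query: What does the doc say?
```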