diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 193743971..91b07e0aa 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -1,3 +1,4 @@
+import re
 from typing import List
 
 from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
             pass
 
     return merge_and_sort_query_results(results, k)
+
+
+def rag_template(template: str, context: str, query: str):
+    template = re.sub(r"\[context\]", context, template)
+    template = re.sub(r"\[query\]", query, template)
+
+    return template
diff --git a/backend/main.py b/backend/main.py
index afa974ca6..cc5edc0f9 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -12,6 +12,7 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
 
 
 from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
 
 from apps.web.main import app as webui_app
 
+from apps.rag.utils import query_doc, query_collection, rag_template
+
 
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
 from constants import ERROR_MESSAGES
@@ -56,6 +59,89 @@ async def on_startup():
     await litellm_app_startup()
 
 
+class RAGMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+
+        print(request.url.path)
+        if request.method == "POST":
+            # Read the original request body
+            body = await request.body()
+            # Decode body to string
+            body_str = body.decode("utf-8")
+            # Parse string to JSON
+            data = json.loads(body_str) if body_str else {}
+
+            # Example: Add a new key-value pair or modify existing ones
+            # data["modified"] = True  # Example modification
+            if "docs" in data:
+                docs = data["docs"]
+                print(docs)
+
+                last_user_message_idx = None
+                for i in range(len(data["messages"]) - 1, -1, -1):
+                    if data["messages"][i]["role"] == "user":
+                        last_user_message_idx = i
+                        break
+
+                query = data["messages"][last_user_message_idx]["content"]
+
+                relevant_contexts = []
+
+                for doc in docs:
+                    context = None
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=rag_app.state.TOP_K,
+                            embedding_function=rag_app.state.sentence_transformer_ef,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=rag_app.state.TOP_K,
+                            embedding_function=rag_app.state.sentence_transformer_ef,
+                        )
+                    relevant_contexts.append(context)
+
+                context_string = ""
+                for context in relevant_contexts:
+                    if context:
+                        context_string += " ".join(context["documents"][0]) + "\n"
+
+                content = rag_template(
+                    template=rag_app.state.RAG_TEMPLATE,
+                    context=context_string,
+                    query=query,
+                )
+
+                new_user_message = {
+                    **data["messages"][last_user_message_idx],
+                    "content": content,
+                }
+                data["messages"][last_user_message_idx] = new_user_message
+                del data["docs"]
+
+                print("DATAAAAAAAAAAAAAAAAAA")
+                print(data)
+            modified_body_bytes = json.dumps(data).encode("utf-8")
+
+            # Create a new request with the modified body
+            scope = request.scope
+            scope["body"] = modified_body_bytes
+            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+
+        response = await call_next(request)
+        return response
+
+    async def _receive(self, body: bytes):
+        return {"type": "http.request", "body": body, "more_body": False}
+
+
+app.add_middleware(RAGMiddleware)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 4e8e9b14c..6dcfbbe7d 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -252,7 +252,7 @@ export const queryCollection = async (
 	token: string,
 	collection_names: string,
 	query: string,
-	k: number
+	k: number | null = null
 ) => {
 	let error = null;
 
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 5b968ac90..e5510a066 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -232,53 +232,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -368,6 +321,13 @@
 			}
 		});
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
 			model: model,
 			messages: messagesBody,
@@ -375,7 +335,8 @@
 				...($settings.options ?? {})
 			},
 			format: $settings.requestFormat ?? undefined,
-			keep_alive: $settings.keepAlive ?? undefined
+			keep_alive: $settings.keepAlive ?? undefined,
+			docs: docs
 		});
 
 		if (res && res.ok) {
@@ -535,6 +496,13 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
+		const docs = messages
+			.filter((message) => message?.files ?? null)
+			.map((message) =>
+				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+			)
+			.flat(1);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -583,7 +551,8 @@
 				top_p: $settings?.options?.top_p ?? undefined,
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-				max_tokens: $settings?.options?.num_predict ?? undefined
+				max_tokens: $settings?.options?.num_predict ?? undefined,
+				docs: docs
 			},
 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 		);
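
Review notes (not part of the patch):

1. rag_template passes retrieved text as the replacement argument of re.sub, which expands backslash escapes and group references such as "\g<0>"; a context chunk containing a stray backslash can corrupt the prompt or raise re.error ("bad escape"). Since the placeholders are fixed strings, a literal str.replace is the safer equivalent. A minimal sketch with the same signature as the patch:

    def rag_template(template: str, context: str, query: str) -> str:
        # str.replace substitutes literally, so backslashes or group
        # references in the retrieved context cannot be reinterpreted
        # the way re.sub's replacement argument is.
        template = template.replace("[context]", context)
        template = template.replace("[query]", query)
        return template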
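   The helper assumes the configured RAG_TEMPLATE contains literal [context] and [query] markers. An illustrative usage (the template string here is made up, not the shipped default):

    template = "Use the following context to answer the query:\n[context]\nQuery: [query]"
    prompt = rag_template(
        template=template,
        context="Employees accrue 20 days of PTO per year.",
        query="How much PTO do I get?",
    )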
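2. The request rebuild at the end of RAGMiddleware.dispatch is load-bearing: downstream ASGI apps read the body from the receive channel, not from the Request object the middleware saw, so after mutating the JSON the new bytes must be replayed through a fresh receive callable or the sub-application would wait on a body that was already consumed. A stripped-down sketch of the same pattern, assuming only Starlette (class and key names here are hypothetical):

    import json

    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request

    class BodyRewriteMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            if request.method == "POST":
                data = json.loads(await request.body() or b"{}")
                data["rewritten"] = True  # stand-in for the RAG injection
                body = json.dumps(data).encode("utf-8")

                async def receive():
                    # Replay the modified body as a single ASGI message.
                    return {"type": "http.request", "body": body, "more_body": False}

                request = Request(request.scope, receive=receive)
            return await call_next(request)

   One caveat: the Content-Length header in the scope still describes the original body, which can matter when a downstream handler forwards the headers verbatim to another HTTP server.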
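3. The body shape the middleware expects is implied by its two branches and by the frontend filter on item.type; roughly the following, with the keys taken from the patch and all values hypothetical (collection_names is shown as a list, matching how query_collection iterates it):

    body = {
        "model": "llama2",
        "messages": [{"role": "user", "content": "How much PTO do I get?"}],
        "docs": [
            {"type": "doc", "collection_name": "a1b2c3"},
            {"type": "collection", "collection_names": ["hr-docs", "it-docs"]},
        ],
    }

   Note that "docs" is deleted before the body is forwarded, so the Ollama/OpenAI backends only ever see the rewritten last user message.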