From 5270efa9e5cf5d39de326930344ac94611622935 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek" <timothyjrbeck@gmail.com>
Date: Sat, 17 Feb 2024 22:41:03 -0800
Subject: [PATCH] feat: editable rag template

---
 backend/apps/rag/main.py             | 22 +++++++++++
 backend/config.py                    | 15 ++++++++
 src/lib/apis/rag/index.ts            | 57 ++++++++++++++++++++++++++++
 src/lib/utils/rag/index.ts           | 32 +++++++++-------
 src/routes/(app)/+page.svelte        |  6 ++-
 src/routes/(app)/c/[id]/+page.svelte |  6 ++-
 6 files changed, 122 insertions(+), 16 deletions(-)

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 79981680f..fbfba258d 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -62,6 +62,7 @@ from config import (
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
+    RAG_TEMPLATE,
 )
 from constants import ERROR_MESSAGES
 
@@ -73,6 +74,8 @@ app = FastAPI()
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.RAG_TEMPLATE = RAG_TEMPLATE
+
 
 origins = ["*"]
 
@@ -154,6 +157,25 @@ async def update_chunk_params(
     }
 
 
+@app.get("/template")
+async def get_rag_template(user=Depends(get_current_user)):
+    return {
+        "status": True,
+        "template": app.state.RAG_TEMPLATE,
+    }
+
+
+class RAGTemplateForm(BaseModel):
+    template: str
+
+
+@app.post("/template/update")
+async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
+    # TODO: check template requirements
+    app.state.RAG_TEMPLATE = form_data.template
+    return {"status": True, "template": app.state.RAG_TEMPLATE}
+
+
 class QueryDocForm(BaseModel):
     collection_name: str
     query: str
diff --git a/backend/config.py b/backend/config.py
index f5acf06b7..440256c48 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100
 
+
+RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+<context>
+    [context]
+</context>
+
+When answer to user:
+- If you don't know, just say that you don't know.
+- If you don't know when you are not sure, ask for clarification.
+Avoid mentioning that you obtained the information from the context.
+And answer according to the language of the user's question.
+        
+Given the context information, answer the query.
+Query: [query]"""
+
 ####################################
 # Transcribe
 ####################################
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 5819badbd..78c220b6f 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -58,6 +58,63 @@ export const updateChunkParams = async (token: string, size: number, overlap: nu
 	return res;
 };
 
+export const getRAGTemplate = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/template`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateRAGTemplate = async (token: string, template: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			template: template
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
 	const data = new FormData();
 	data.append('file', file);
diff --git a/src/lib/utils/rag/index.ts b/src/lib/utils/rag/index.ts
index 6b219ef20..ba1f29f88 100644
--- a/src/lib/utils/rag/index.ts
+++ b/src/lib/utils/rag/index.ts
@@ -1,17 +1,21 @@
-export const RAGTemplate = (context: string, query: string) => {
-	let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
-	<context>
-	  [context]
-	</context>
-	
-	When answer to user:
-	- If you don't know, just say that you don't know.
-	- If you don't know when you are not sure, ask for clarification.
-	Avoid mentioning that you obtained the information from the context.
-	And answer according to the language of the user's question.
-			
-	Given the context information, answer the query.
-	Query: [query]`;
+import { getRAGTemplate } from '$lib/apis/rag';
+
+export const RAGTemplate = async (token: string, context: string, query: string) => {
+	let template = await getRAGTemplate(token).catch(() => {
+		return `Use the following context as your learned knowledge, inside <context></context> XML tags.
+		<context>
+		  [context]
+		</context>
+		
+		When answer to user:
+		- If you don't know, just say that you don't know.
+		- If you don't know when you are not sure, ask for clarification.
+		Avoid mentioning that you obtained the information from the context.
+		And answer according to the language of the user's question.
+				
+		Given the context information, answer the query.
+		Query: [query]`;
+	});
 
 	template = template.replace(/\[context\]/g, context);
 	template = template.replace(/\[query\]/g, query);
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 604cb544d..1d91a6144 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -266,7 +266,11 @@
 
 			console.log(contextString);
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
+			history.messages[parentId].raContent = await RAGTemplate(
+				localStorage.token,
+				contextString,
+				query
+			);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();
 			processing = '';
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index aab03d74f..b719ebf2b 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -280,7 +280,11 @@
 
 			console.log(contextString);
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
+			history.messages[parentId].raContent = await RAGTemplate(
+				localStorage.token,
+				contextString,
+				query
+			);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();
 			processing = '';