diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 79981680f..fbfba258d 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -62,6 +62,7 @@ from config import ( CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, + RAG_TEMPLATE, ) from constants import ERROR_MESSAGES @@ -73,6 +74,8 @@ app = FastAPI() app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP +app.state.RAG_TEMPLATE = RAG_TEMPLATE + origins = ["*"] @@ -154,6 +157,25 @@ async def update_chunk_params( } +@app.get("/template") +async def get_rag_template(user=Depends(get_current_user)): + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + } + + +class RAGTemplateForm(BaseModel): + template: str + + +@app.post("/template/update") +async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)): + # TODO: check template requirements + app.state.RAG_TEMPLATE = form_data.template + return {"status": True, "template": app.state.RAG_TEMPLATE} + + class QueryDocForm(BaseModel): collection_name: str query: str diff --git a/backend/config.py b/backend/config.py index f5acf06b7..440256c48 100644 --- a/backend/config.py +++ b/backend/config.py @@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient( CHUNK_SIZE = 1500 CHUNK_OVERLAP = 100 + +RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags. + + [context] + + +When answer to user: +- If you don't know, just say that you don't know. +- If you don't know when you are not sure, ask for clarification. +Avoid mentioning that you obtained the information from the context. +And answer according to the language of the user's question. + +Given the context information, answer the query. +Query: [query]""" + #################################### # Transcribe #################################### diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 5819badbd..78c220b6f 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -58,6 +58,63 @@ export const updateChunkParams = async (token: string, size: number, overlap: nu return res; }; +export const getRAGTemplate = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/template`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateRAGTemplate = async (token: string, template: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/template/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + template: template + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { const data = new FormData(); data.append('file', file); diff --git a/src/lib/utils/rag/index.ts b/src/lib/utils/rag/index.ts index 6b219ef20..ba1f29f88 100644 --- a/src/lib/utils/rag/index.ts +++ b/src/lib/utils/rag/index.ts @@ -1,17 +1,21 @@ -export const RAGTemplate = (context: string, query: string) => { - let template = `Use the following context as your learned knowledge, inside XML tags. - - [context] - - - When answer to user: - - If you don't know, just say that you don't know. - - If you don't know when you are not sure, ask for clarification. - Avoid mentioning that you obtained the information from the context. - And answer according to the language of the user's question. - - Given the context information, answer the query. - Query: [query]`; +import { getRAGTemplate } from '$lib/apis/rag'; + +export const RAGTemplate = async (token: string, context: string, query: string) => { + let template = await getRAGTemplate(token).catch(() => { + return `Use the following context as your learned knowledge, inside XML tags. + + [context] + + + When answer to user: + - If you don't know, just say that you don't know. + - If you don't know when you are not sure, ask for clarification. + Avoid mentioning that you obtained the information from the context. + And answer according to the language of the user's question. + + Given the context information, answer the query. + Query: [query]`; + }); template = template.replace(/\[context\]/g, context); template = template.replace(/\[query\]/g, query); diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 604cb544d..1d91a6144 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -266,7 +266,11 @@ console.log(contextString); - history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].raContent = await RAGTemplate( + localStorage.token, + contextString, + query + ); history.messages[parentId].contexts = relevantContexts; await tick(); processing = ''; diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index aab03d74f..b719ebf2b 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -280,7 +280,11 @@ console.log(contextString); - history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].raContent = await RAGTemplate( + localStorage.token, + contextString, + query + ); history.messages[parentId].contexts = relevantContexts; await tick(); processing = '';