From 50f7b20ac293b2f17de0f382bbcd7ae5d0f89349 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 1 Feb 2024 13:35:41 -0800 Subject: [PATCH] refac --- backend/apps/rag/main.py | 46 +++++++++++++++++++++++----- src/lib/apis/rag/index.ts | 31 +++++++++---------- src/routes/(app)/+page.svelte | 36 ++++++++++++---------- src/routes/(app)/c/[id]/+page.svelte | 36 ++++++++++++---------- 4 files changed, 91 insertions(+), 58 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 85bc995ae..eec3dfa23 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -10,6 +10,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware import os, shutil +from typing import List # from chromadb.utils import embedding_functions @@ -96,19 +97,22 @@ async def get_status(): return {"status": True} -@app.get("/query/{collection_name}") +class QueryCollectionForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = 4 + + +@app.post("/query/collection") def query_collection( - collection_name: str, - query: str, - k: Optional[int] = 4, + form_data: QueryCollectionForm, user=Depends(get_current_user), ): try: collection = CHROMA_CLIENT.get_collection( - name=collection_name, + name=form_data.collection_name, ) - result = collection.query(query_texts=[query], n_results=k) - + result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: print(e) @@ -118,6 +122,34 @@ def query_collection( ) +class QueryCollectionsForm(BaseModel): + collection_names: List[str] + query: str + k: Optional[int] = 4 + + +@app.post("/query/collections") +def query_collections( + form_data: QueryCollectionsForm, + user=Depends(get_current_user), +): + results = [] + + for collection_name in form_data.collection_names: + try: + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query( + query_texts=[form_data.query], n_results=form_data.k + ) + results.append(result) + except: + pass + + return results + + @app.post("/web") def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 8a2b8cb4e..e2656943e 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string export const queryVectorDB = async ( token: string, - collection_name: string, + collection_names: string[], query: string, k: number ) => { let error = null; - const searchParams = new URLSearchParams(); - searchParams.set('query', query); - if (k) { - searchParams.set('k', k.toString()); - } - - const res = await fetch( - `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, - { - method: 'GET', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - } - ) + const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_names: collection_names, + query: query, + k: k + }) + }) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index eebf7743a..657c9d9f2 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -232,26 +232,28 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); - }) - ); - relevantContexts = relevantContexts.filter((context) => context); + let relevantContexts = await queryVectorDB( + localStorage.token, + docs.map((d) => d.collection_name), + query, + 4 + ).catch((error) => { + console.log(error); + return null; + }); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + if (relevantContexts) { + relevantContexts = relevantContexts.filter((context) => context); - console.log(contextString); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; + console.log(contextString); + + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; + } await tick(); processing = ''; } diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index c161435df..5509019a7 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -246,26 +246,28 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); - }) - ); - relevantContexts = relevantContexts.filter((context) => context); + let relevantContexts = await queryVectorDB( + localStorage.token, + docs.map((d) => d.collection_name), + query, + 4 + ).catch((error) => { + console.log(error); + return null; + }); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + if (relevantContexts) { + relevantContexts = relevantContexts.filter((context) => context); - console.log(contextString); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; + console.log(contextString); + + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; + } await tick(); processing = ''; }