diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index e2656943e..08dff5bfa 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -66,13 +66,13 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string export const queryVectorDB = async ( token: string, - collection_names: string[], + collection_name: string, query: string, k: number ) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { method: 'POST', headers: { Accept: 'application/json', @@ -80,7 +80,7 @@ export const queryVectorDB = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - collection_names: collection_names, + collection_name: collection_name, query: query, k: k }) diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 657c9d9f2..956b6cb04 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -28,7 +28,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryVectorDB } from '$lib/apis/rag'; + import { queryCollection } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -232,28 +232,26 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await queryVectorDB( - localStorage.token, - docs.map((d) => d.collection_name), - query, - 4 - ).catch((error) => { - console.log(error); - return null; - }); + let relevantContexts = await Promise.all( + docs.map(async (doc) => { + return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + }) + ); + relevantContexts = relevantContexts.filter((context) => context); - if (relevantContexts) { - relevantContexts = relevantContexts.filter((context) => context); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + console.log(contextString); - console.log(contextString); - - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; - } + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; await tick(); processing = ''; } diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 5509019a7..fac8a01ce 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -29,7 +29,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryVectorDB } from '$lib/apis/rag'; + import { queryCollection } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -246,28 +246,26 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await queryVectorDB( - localStorage.token, - docs.map((d) => d.collection_name), - query, - 4 - ).catch((error) => { - console.log(error); - return null; - }); + let relevantContexts = await Promise.all( + docs.map(async (doc) => { + return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + }) + ); + relevantContexts = relevantContexts.filter((context) => context); - if (relevantContexts) { - relevantContexts = relevantContexts.filter((context) => context); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + console.log(contextString); - console.log(contextString); - - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; - } + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; await tick(); processing = ''; }