From 64ed0d10897695c4753cd3e77dd7f7cbba2da3e5 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 6 May 2024 15:49:00 -0700 Subject: [PATCH] refac: include source name to citation --- backend/apps/rag/utils.py | 37 ++++++-------- .../chat/Messages/CitationsModal.svelte | 4 +- .../chat/Messages/ResponseMessage.svelte | 50 ++++++++----------- 3 files changed, 38 insertions(+), 53 deletions(-) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 944abb219..1b3694fdb 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -271,14 +271,14 @@ def rag_messages( for doc in docs: context = None - collection = doc.get("collection_name") - if collection: - collection = [collection] - else: - collection = doc.get("collection_names", []) + collection_names = ( + doc["collection_names"] + if doc["type"] == "collection" + else [doc["collection_name"]] + ) - collection = set(collection).difference(extracted_collections) - if not collection: + collection_names = set(collection_names).difference(extracted_collections) + if not collection_names: log.debug(f"skipping {doc} as it has already been extracted") continue @@ -288,11 +288,7 @@ def rag_messages( else: if hybrid_search: context = query_collection_with_hybrid_search( - collection_names=( - doc["collection_names"] - if doc["type"] == "collection" - else [doc["collection_name"]] - ), + collection_names=collection_names, query=query, embedding_function=embedding_function, k=k, @@ -301,11 +297,7 @@ def rag_messages( ) else: context = query_collection( - collection_names=( - doc["collection_names"] - if doc["type"] == "collection" - else [doc["collection_name"]] - ), + collection_names=collection_names, query=query, embedding_function=embedding_function, k=k, @@ -315,9 +307,9 @@ def rag_messages( context = None if context: - relevant_contexts.append(context) + relevant_contexts.append({**context, "source": doc}) - extracted_collections.extend(collection) + extracted_collections.extend(collection_names) context_string = "" @@ -325,11 +317,14 @@ def rag_messages( for context in relevant_contexts: try: if "documents" in context: - items = [item for item in context["documents"][0] if item is not None] - context_string += "\n\n".join(items) + context_string += "\n\n".join( + [text for text in context["documents"][0] if text is not None] + ) + if "metadatas" in context: citations.append( { + "source": context["source"], "document": context["documents"][0], "metadata": context["metadatas"][0], } diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte index 8a030b7b8..e6d171d0d 100644 --- a/src/lib/components/chat/Messages/CitationsModal.svelte +++ b/src/lib/components/chat/Messages/CitationsModal.svelte @@ -10,10 +10,10 @@ let mergedDocuments = []; onMount(async () => { - console.log(citation); // Merge the document with its metadata mergedDocuments = citation.document?.map((c, i) => { return { + source: citation.source, document: c, metadata: citation.metadata?.[i] }; @@ -54,7 +54,7 @@ {$i18n.t('Source')}
- {document.metadata?.source ?? $i18n.t('No source available')} + {document.source?.name ?? $i18n.t('No source available')}
diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index b7a4f259c..d604ba3c9 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -66,9 +66,8 @@ let showRateComment = false; - let showCitations = {}; // Backend returns a list of citations per collection, we flatten it to citations per source - let flattenedCitations = {}; + let citations = {}; $: tokens = marked.lexer(sanitizeResponseContent(message.content)); @@ -137,27 +136,21 @@ } if (message.citations) { - for (const citation of message.citations) { - const zipped = (citation?.document ?? []).map(function (document, index) { - return [document, citation.metadata?.[index]]; + message.citations.forEach((citation) => { + citation.document.forEach((document, index) => { + const metadata = citation.metadata?.[index]; + const source = citation?.source?.name ?? metadata?.source ?? 'N/A'; + + citations[source] = citations[source] || { + source: citation.source, + document: [], + metadata: [] + }; + + citations[source].document.push(document); + citations[source].metadata.push(metadata); }); - - for (const [document, metadata] of zipped) { - const source = metadata?.source ?? 'N/A'; - if (source in flattenedCitations) { - flattenedCitations[source].document.push(document); - flattenedCitations[source].metadata.push(metadata); - } else { - flattenedCitations[source] = { - document: [document], - metadata: [metadata] - }; - } - } - } - - console.log(flattenedCitations); - console.log(Object.keys(flattenedCitations)); + }); } }; @@ -474,15 +467,12 @@
- {#if Object.keys(flattenedCitations).length > 0} + {#if Object.keys(citations).length > 0}
- {#each [...Object.keys(flattenedCitations)] as source, idx} - + {#each Object.keys(citations) as source, idx} +
@@ -492,10 +482,10 @@
{/each}