feat: show RAG query results as citations

This commit is contained in:
Jun Siang Cheah 2024-05-06 21:14:51 +08:00 committed by Timothy J. Baek
parent ba09fcd548
commit 0872bea790
7 changed files with 234 additions and 42 deletions

View File

@ -320,11 +320,19 @@ def rag_messages(
extracted_collections.extend(collection) extracted_collections.extend(collection)
context_string = "" context_string = ""
citations = []
for context in relevant_contexts: for context in relevant_contexts:
try: try:
if "documents" in context: if "documents" in context:
items = [item for item in context["documents"][0] if item is not None] items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items) context_string += "\n\n".join(items)
if "metadatas" in context:
citations.append(
{
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context_string = context_string.strip() context_string = context_string.strip()
@ -355,7 +363,7 @@ def rag_messages(
messages[last_user_message_idx] = new_user_message messages[last_user_message_idx] = new_user_message
return messages return messages, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):

View File

@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app from apps.openai.main import app as openai_app
@ -102,6 +102,8 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware): class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False
if request.method == "POST" and ( if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path "/api/chat" in request.url.path or "/chat/completions" in request.url.path
): ):
@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON # Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones # Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification # data["modified"] = True # Example modification
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"] = rag_messages( data["messages"], citations = rag_messages(
docs=data["docs"], docs=data["docs"],
messages=data["messages"], messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE, template=rag_app.state.RAG_TEMPLATE,
@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
) )
del data["docs"] del data["docs"]
log.debug(f"data['messages']: {data['messages']}") log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
)
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
] ]
response = await call_next(request) response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
)
return response return response
# Build the ASGI "http.request" message that delivers `body` as the request
# payload; `more_body` is False, so the message is marked as the final chunk.
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
    """Prepend a citations event to an SSE (``text/event-stream``) body.

    Args:
        original_generator: Async iterable of chunks produced by the
            wrapped response. Every chunk is forwarded unchanged.
        citations: JSON-serialisable citation data, emitted exactly once
            before the first upstream chunk.

    Yields:
        First a ``data: {"citations": ...}`` event terminated by a blank
        line, then each upstream chunk in its original order.
    """
    payload = json.dumps({"citations": citations})
    # SSE events are delimited by an empty line, hence the double newline.
    yield f"data: {payload}\n\n"
    async for chunk in original_generator:
        yield chunk
async def ollama_stream_wrapper(self, original_generator, citations):
    """Prepend a citations line to an NDJSON (``application/x-ndjson``) body.

    Args:
        original_generator: Async iterable of chunks produced by the
            wrapped response. Every chunk is forwarded unchanged.
        citations: JSON-serialisable citation data, emitted exactly once
            before the first upstream chunk.

    Yields:
        First a single ``{"citations": ...}`` JSON line, then each upstream
        chunk in its original order.
    """
    payload = json.dumps({"citations": citations})
    # NDJSON records are newline-delimited: one JSON document per line.
    yield f"{payload}\n"
    async for chunk in original_generator:
        yield chunk
app.add_middleware(RAGMiddleware) app.add_middleware(RAGMiddleware)

View File

@ -4,6 +4,8 @@ import type { ParsedEvent } from 'eventsource-parser';
// One update yielded by the text-stream iterators in this file: `done` is the
// completion flag consumers check, `value` carries the text content, and
// `citations` (when present) carries the citations payload parsed from the
// stream; updates that carry citations are yielded with an empty `value`.
type TextStreamUpdate = { type TextStreamUpdate = {
done: boolean; done: boolean;
value: string; value: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
citations?: any;
}; };
// createOpenAITextStream takes a responseBody with a SSE response, // createOpenAITextStream takes a responseBody with a SSE response,
@ -45,6 +47,11 @@ async function* openAIStreamToIterator(
const parsedData = JSON.parse(data); const parsedData = JSON.parse(data);
console.log(parsedData); console.log(parsedData);
if (parsedData.citations) {
yield { done: false, value: '', citations: parsedData.citations };
continue;
}
yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' }; yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
} catch (e) { } catch (e) {
console.error('Error extracting delta from SSE event:', e); console.error('Error extracting delta from SSE event:', e);
@ -62,6 +69,10 @@ async function* streamLargeDeltasAsRandomChunks(
yield textStreamUpdate; yield textStreamUpdate;
return; return;
} }
if (textStreamUpdate.citations) {
yield textStreamUpdate;
continue;
}
let content = textStreamUpdate.value; let content = textStreamUpdate.value;
if (content.length < 5) { if (content.length < 5) {
yield { done: false, value: content }; yield { done: false, value: content };

View File

@ -0,0 +1,75 @@
<script lang="ts">
	import { getContext } from 'svelte';
	import Modal from '$lib/components/common/Modal.svelte';

	const i18n = getContext('i18n');

	// Whether the modal is visible; bound by the parent component.
	export let show = false;

	// A single citation entry: `document` and `metadata` are parallel arrays,
	// where metadata[i] describes document[i].
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	export let citation: { document?: any[]; metadata?: any[] } | undefined;

	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	let mergedDocuments: { document: any; metadata: any }[] = [];

	// Pair every document chunk with the metadata at the same index.
	// Declared reactively so the list follows later changes to `citation`, and
	// defaulting to [] so the {#each} below never receives undefined.
	$: mergedDocuments = (citation?.document ?? []).map((chunk, i) => ({
		document: chunk,
		metadata: citation?.metadata?.[i]
	}));
</script>

<Modal size="lg" bind:show>
	<div>
		<div class=" flex justify-between dark:text-gray-300 px-5 py-4">
			<div class=" text-lg font-medium self-center capitalize">
				{$i18n.t('Citation')}
			</div>
			<button
				class="self-center"
				on:click={() => {
					show = false;
				}}
			>
				<svg
					xmlns="http://www.w3.org/2000/svg"
					viewBox="0 0 20 20"
					fill="currentColor"
					class="w-5 h-5"
				>
					<path
						d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
					/>
				</svg>
			</button>
		</div>
		<hr class=" dark:border-gray-850" />

		<div class="flex flex-col w-full px-5 py-4 dark:text-gray-200 overflow-y-scroll max-h-[22rem]">
			<!-- `entry` rather than `document`, to avoid shadowing the global document object -->
			{#each mergedDocuments as entry}
				<!-- Source taken from the chunk's metadata; metadata may be missing -->
				<div class="flex flex-col w-full">
					<div class="text-lg font-medium dark:text-gray-300">
						{$i18n.t('Source')}
					</div>
					<div class="text-sm dark:text-gray-400">
						{entry.metadata?.source ?? 'N/A'}
					</div>
				</div>
				<!-- Chunk text; kept on one line so <pre> renders no stray whitespace -->
				<div class="flex flex-col w-full">
					<div class="text-lg font-medium dark:text-gray-300">
						{$i18n.t('Content')}
					</div>
					<pre class="text-sm dark:text-gray-400">{entry.document}</pre>
				</div>
				<hr class=" dark:border-gray-850" />
			{/each}
		</div>
	</div>
</Modal>

View File

@ -32,6 +32,7 @@
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import RateComment from './RateComment.svelte'; import RateComment from './RateComment.svelte';
import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
export let modelfiles = []; export let modelfiles = [];
export let message; export let message;
@ -65,6 +66,8 @@
let showRateComment = false; let showRateComment = false;
let showCitations = {};
$: tokens = marked.lexer(sanitizeResponseContent(message.content)); $: tokens = marked.lexer(sanitizeResponseContent(message.content));
const renderer = new marked.Renderer(); const renderer = new marked.Renderer();
@ -360,6 +363,48 @@
{/each} {/each}
</div> </div>
{/if} {/if}
{#if message.citations}
	<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
		<!--
			Visibility is tracked per loop index. Using the citation object itself as
			the key would coerce it to "[object Object]", making every citation share
			one flag so that a single click toggles all modals at once.
		-->
		{#each message.citations as citation, citationIdx}
			<div>
				<CitationsModal bind:show={showCitations[citationIdx]} {citation} />
				<button
					class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
					type="button"
					on:click={() => {
						showCitations[citationIdx] = !showCitations[citationIdx];
					}}
				>
					<div class="p-2.5 bg-red-400 text-white rounded-lg">
						<svg
							xmlns="http://www.w3.org/2000/svg"
							viewBox="0 0 24 24"
							fill="currentColor"
							class="w-6 h-6"
						>
							<path
								fill-rule="evenodd"
								d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
								clip-rule="evenodd"
							/>
							<path
								d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
							/>
						</svg>
					</div>
					<div class="flex flex-col justify-center -space-y-0.5">
						<!-- Label the button with the first chunk's source, if any -->
						<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
							{citation.metadata?.[0]?.source ?? 'N/A'}
						</div>
						<div class=" text-gray-500 text-sm">{$i18n.t('Document')}</div>
					</div>
				</button>
			</div>
		{/each}
	</div>
{/if}
<div <div
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-ol:p-0 prose-li:-mb-4 whitespace-pre-line" class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-ol:p-0 prose-li:-mb-4 whitespace-pre-line"
@ -577,10 +622,11 @@
stroke-linejoin="round" stroke-linejoin="round"
class="w-4 h-4" class="w-4 h-4"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
><path
d="M14 9V5a3 3 0 0 0-3-3l-4 9v11h11.28a2 2 0 0 0 2-1.7l1.38-9a2 2 0 0 0-2-2.3zM7 22H4a2 2 0 0 1-2-2v-7a2 2 0 0 1 2-2h3"
/></svg
> >
<path
d="M14 9V5a3 3 0 0 0-3-3l-4 9v11h11.28a2 2 0 0 0 2-1.7l1.38-9a2 2 0 0 0-2-2.3zM7 22H4a2 2 0 0 1-2-2v-7a2 2 0 0 1 2-2h3"
/>
</svg>
</button> </button>
</Tooltip> </Tooltip>
@ -611,10 +657,11 @@
stroke-linejoin="round" stroke-linejoin="round"
class="w-4 h-4" class="w-4 h-4"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
><path
d="M10 15v4a3 3 0 0 0 3 3l4-9V2H5.72a2 2 0 0 0-2 1.7l-1.38 9a2 2 0 0 0 2 2.3zm7-13h2.67A2.31 2.31 0 0 1 22 4v7a2.31 2.31 0 0 1-2.33 2H17"
/></svg
> >
<path
d="M10 15v4a3 3 0 0 0 3 3l4-9V2H5.72a2 2 0 0 0-2 1.7l-1.38 9a2 2 0 0 0 2 2.3zm7-13h2.67A2.31 2.31 0 0 1 22 4v7a2.31 2.31 0 0 1-2.33 2H17"
/>
</svg>
</button> </button>
</Tooltip> </Tooltip>
{/if} {/if}
@ -637,35 +684,32 @@
fill="currentColor" fill="currentColor"
viewBox="0 0 24 24" viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
><style> >
<style>
.spinner_S1WN { .spinner_S1WN {
animation: spinner_MGfb 0.8s linear infinite; animation: spinner_MGfb 0.8s linear infinite;
animation-delay: -0.8s; animation-delay: -0.8s;
} }
.spinner_Km9P { .spinner_Km9P {
animation-delay: -0.65s; animation-delay: -0.65s;
} }
.spinner_JApP { .spinner_JApP {
animation-delay: -0.5s; animation-delay: -0.5s;
} }
@keyframes spinner_MGfb { @keyframes spinner_MGfb {
93.75%, 93.75%,
100% { 100% {
opacity: 0.2; opacity: 0.2;
} }
} }
</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle </style>
class="spinner_S1WN spinner_Km9P" <circle class="spinner_S1WN" cx="4" cy="12" r="3" />
cx="12" <circle class="spinner_S1WN spinner_Km9P" cx="12" cy="12" r="3" />
cy="12" <circle class="spinner_S1WN spinner_JApP" cx="20" cy="12" r="3" />
r="3" </svg>
/><circle
class="spinner_S1WN spinner_JApP"
cx="20"
cy="12"
r="3"
/></svg
>
{:else if speaking} {:else if speaking}
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
@ -718,35 +762,32 @@
fill="currentColor" fill="currentColor"
viewBox="0 0 24 24" viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
><style> >
<style>
.spinner_S1WN { .spinner_S1WN {
animation: spinner_MGfb 0.8s linear infinite; animation: spinner_MGfb 0.8s linear infinite;
animation-delay: -0.8s; animation-delay: -0.8s;
} }
.spinner_Km9P { .spinner_Km9P {
animation-delay: -0.65s; animation-delay: -0.65s;
} }
.spinner_JApP { .spinner_JApP {
animation-delay: -0.5s; animation-delay: -0.5s;
} }
@keyframes spinner_MGfb { @keyframes spinner_MGfb {
93.75%, 93.75%,
100% { 100% {
opacity: 0.2; opacity: 0.2;
} }
} }
</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle </style>
class="spinner_S1WN spinner_Km9P" <circle class="spinner_S1WN" cx="4" cy="12" r="3" />
cx="12" <circle class="spinner_S1WN spinner_Km9P" cx="12" cy="12" r="3" />
cy="12" <circle class="spinner_S1WN spinner_JApP" cx="20" cy="12" r="3" />
r="3" </svg>
/><circle
class="spinner_S1WN spinner_JApP"
cx="20"
cy="12"
r="3"
/></svg
>
{:else} {:else}
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"

View File

@ -366,7 +366,8 @@
}, },
format: $settings.requestFormat ?? undefined, format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
docs: docs.length > 0 ? docs : undefined docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0
}); });
if (res && res.ok) { if (res && res.ok) {
@ -401,6 +402,11 @@
console.log(line); console.log(line);
let data = JSON.parse(line); let data = JSON.parse(line);
if ('citations' in data) {
responseMessage.citations = data.citations;
continue;
}
if ('detail' in data) { if ('detail' in data) {
throw data; throw data;
} }
@ -598,7 +604,8 @@
num_ctx: $settings?.options?.num_ctx ?? undefined, num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined, max_tokens: $settings?.options?.num_predict ?? undefined,
docs: docs.length > 0 ? docs : undefined docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0
}, },
model?.source?.toLowerCase() === 'litellm' model?.source?.toLowerCase() === 'litellm'
? `${LITELLM_API_BASE_URL}/v1` ? `${LITELLM_API_BASE_URL}/v1`
@ -614,7 +621,7 @@
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) { for await (const update of textStream) {
const { value, done } = update; const { value, done, citations } = update;
if (done || stopResponseFlag || _chatId !== $chatId) { if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true; responseMessage.done = true;
messages = messages; messages = messages;
@ -626,6 +633,11 @@
break; break;
} }
if (citations) {
responseMessage.citations = citations;
continue;
}
if (responseMessage.content == '' && value == '\n') { if (responseMessage.content == '' && value == '\n') {
continue; continue;
} else { } else {

View File

@ -378,7 +378,8 @@
}, },
format: $settings.requestFormat ?? undefined, format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
docs: docs.length > 0 ? docs : undefined docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0
}); });
if (res && res.ok) { if (res && res.ok) {
@ -413,6 +414,11 @@
console.log(line); console.log(line);
let data = JSON.parse(line); let data = JSON.parse(line);
if ('citations' in data) {
responseMessage.citations = data.citations;
continue;
}
if ('detail' in data) { if ('detail' in data) {
throw data; throw data;
} }
@ -610,7 +616,8 @@
num_ctx: $settings?.options?.num_ctx ?? undefined, num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined, max_tokens: $settings?.options?.num_predict ?? undefined,
docs: docs.length > 0 ? docs : undefined docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0
}, },
model?.source?.toLowerCase() === 'litellm' model?.source?.toLowerCase() === 'litellm'
? `${LITELLM_API_BASE_URL}/v1` ? `${LITELLM_API_BASE_URL}/v1`
@ -626,7 +633,7 @@
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) { for await (const update of textStream) {
const { value, done } = update; const { value, done, citations } = update;
if (done || stopResponseFlag || _chatId !== $chatId) { if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true; responseMessage.done = true;
messages = messages; messages = messages;
@ -638,6 +645,11 @@
break; break;
} }
if (citations) {
responseMessage.citations = citations;
continue;
}
if (responseMessage.content == '' && value == '\n') { if (responseMessage.content == '' && value == '\n') {
continue; continue;
} else { } else {