diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index b1142e855..8efa1a9f7 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -320,11 +320,19 @@ def rag_messages(
             extracted_collections.extend(collection)

     context_string = ""
+    citations = []
     for context in relevant_contexts:
         try:
             if "documents" in context:
                 items = [item for item in context["documents"][0] if item is not None]
                 context_string += "\n\n".join(items)
+                if "metadatas" in context:
+                    citations.append(
+                        {
+                            "document": context["documents"][0],
+                            "metadata": context["metadatas"][0],
+                        }
+                    )
         except Exception as e:
             log.exception(e)

     context_string = context_string.strip()
@@ -355,7 +363,7 @@ def rag_messages(

     messages[last_user_message_idx] = new_user_message

-    return messages
+    return messages, citations


 def get_model_path(model: str, update_model: bool = False):
diff --git a/backend/main.py b/backend/main.py
index e36b8296a..dc2175ad5 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
-
+from starlette.responses import StreamingResponse

 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
@@ -102,6 +102,8 @@ origins = ["*"]

 class RAGMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
+        return_citations = False
+
         if request.method == "POST" and (
             "/api/chat" in request.url.path or "/chat/completions" in request.url.path
         ):
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}

+            return_citations = data.get("citations", False)
+            if "citations" in data:
+                del data["citations"]
+
             # Example: Add a new key-value pair or modify existing ones
             # data["modified"] = True  # Example modification
             if "docs" in data:
                 data = {**data}
-                data["messages"] = rag_messages(
+                data["messages"], citations = rag_messages(
                     docs=data["docs"],
                     messages=data["messages"],
                     template=rag_app.state.RAG_TEMPLATE,
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 )
                 del data["docs"]

-                log.debug(f"data['messages']: {data['messages']}")
+                log.debug(
+                    f"data['messages']: {data['messages']}, citations: {citations}"
+                )

             modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
             ]

         response = await call_next(request)
+
+        if return_citations:
+            # Inject the citations into the response
+            if isinstance(response, StreamingResponse):
+                # If it's a streaming response, inject it as SSE event or NDJSON line
+                content_type = response.headers.get("Content-Type")
+                if "text/event-stream" in content_type:
+                    return StreamingResponse(
+                        self.openai_stream_wrapper(response.body_iterator, citations),
+                    )
+                if "application/x-ndjson" in content_type:
+                    return StreamingResponse(
+                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                    )
+
         return response

     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}

+    async def openai_stream_wrapper(self, original_generator, citations):
+        yield f"data: {json.dumps({'citations': citations})}\n\n"
+        async for data in original_generator:
+            yield data
+
+    async def ollama_stream_wrapper(self, original_generator, citations):
+        yield f"{json.dumps({'citations': citations})}\n"
+        async for data in original_generator:
+            yield data
+

 app.add_middleware(RAGMiddleware)
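Taken together, the two wrappers above prepend exactly one citations frame and then relay the upstream model stream untouched: an SSE event of the form `data: {"citations": [...]}` for `/chat/completions`, and a bare JSON line for the NDJSON `/api/chat` stream. Below is a minimal client-side sketch of peeling that frame off an NDJSON stream, assuming only the Fetch and Web Streams APIs; `ndjsonLines` and `readChat` are illustrative helpers, not part of this patch.

```ts
// Illustrative helper: split an NDJSON body into parsed JSON values.
async function* ndjsonLines(body: ReadableStream<Uint8Array>): AsyncGenerator<unknown> {
	const reader = body.pipeThrough(new TextDecoderStream()).getReader();
	let buffer = '';
	for (;;) {
		const { value, done } = await reader.read();
		if (done) break;
		buffer += value;
		let newline: number;
		while ((newline = buffer.indexOf('\n')) !== -1) {
			const line = buffer.slice(0, newline).trim();
			buffer = buffer.slice(newline + 1);
			if (line !== '') yield JSON.parse(line);
		}
	}
}

// With citations requested, the first line is {"citations": [...]} and every
// subsequent line is an unmodified Ollama chat chunk.
async function readChat(res: Response) {
	let citations: unknown;
	let content = '';
	if (!res.body) throw new Error('No response body');
	for await (const chunk of ndjsonLines(res.body)) {
		const data = chunk as Record<string, any>;
		if ('citations' in data) {
			citations = data.citations; // frame injected by RAGMiddleware
			continue;
		}
		content += data.message?.content ?? '';
	}
	return { content, citations };
}
```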
diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts
index a72dbe47d..0e87c2524 100644
--- a/src/lib/apis/streaming/index.ts
+++ b/src/lib/apis/streaming/index.ts
@@ -4,6 +4,8 @@ import type { ParsedEvent } from 'eventsource-parser';
 type TextStreamUpdate = {
 	done: boolean;
 	value: string;
+	// eslint-disable-next-line @typescript-eslint/no-explicit-any
+	citations?: any;
 };

 // createOpenAITextStream takes a responseBody with a SSE response,
@@ -45,6 +47,11 @@ async function* openAIStreamToIterator(
 			const parsedData = JSON.parse(data);
 			console.log(parsedData);

+			if (parsedData.citations) {
+				yield { done: false, value: '', citations: parsedData.citations };
+				continue;
+			}
+
 			yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
 		} catch (e) {
 			console.error('Error extracting delta from SSE event:', e);
@@ -62,6 +69,10 @@ async function* streamLargeDeltasAsRandomChunks(
 			yield textStreamUpdate;
 			return;
 		}
+		if (textStreamUpdate.citations) {
+			yield textStreamUpdate;
+			continue;
+		}
 		let content = textStreamUpdate.value;
 		if (content.length < 5) {
 			yield { done: false, value: content };
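With `citations?` on `TextStreamUpdate`, a citation frame travels through the same async iterator as the text deltas, and `streamLargeDeltasAsRandomChunks` passes it through without splitting it. Consumers then need just one extra branch, as the route handlers further down in this patch do. A condensed sketch of that consumption pattern, where `responseMessage` is a plain stand-in for the reactive message object the Svelte routes mutate:

```ts
import { createOpenAITextStream } from '$lib/apis/streaming';

async function readAssistantStream(res: Response) {
	// Stand-in for the reactive message object the Svelte routes mutate.
	const responseMessage: { content: string; citations?: unknown } = { content: '' };
	if (!res.body) throw new Error('No response body');

	const textStream = await createOpenAITextStream(res.body, true);
	for await (const update of textStream) {
		const { value, done, citations } = update;
		if (done) break;
		if (citations) {
			// Citation frames carry no text delta; record them and keep reading.
			responseMessage.citations = citations;
			continue;
		}
		responseMessage.content += value;
	}
	return responseMessage;
}
```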
diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte
new file mode 100644
index 000000000..873fab283
--- /dev/null
+++ b/src/lib/components/chat/Messages/CitationsModal.svelte
@@ -0,0 +1,75 @@
+<script lang="ts">
+	import { getContext } from 'svelte';
+	import Modal from '$lib/components/common/Modal.svelte';
+
+	const i18n = getContext('i18n');
+
+	export let show = false;
+	export let citation;
+
+	let mergedDocuments = [];
+
+	$: if (citation) {
+		// citation.document and citation.metadata are index-aligned lists;
+		// merge them so each retrieved chunk carries its own metadata.
+		mergedDocuments = citation.document?.map((c, i) => {
+			return {
+				document: c,
+				metadata: citation.metadata?.[i]
+			};
+		});
+	}
+</script>
+
+<Modal size="lg" bind:show>
+	<div>
+		<div class="flex justify-between dark:text-gray-300 px-5 py-4">
+			<div class="text-lg font-medium self-center capitalize">
+				{$i18n.t('Citation')}
+			</div>
+			<button
+				class="self-center"
+				on:click={() => {
+					show = false;
+				}}
+			>
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 20 20"
+					fill="currentColor"
+					class="w-5 h-5"
+				>
+					<path
+						d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
+					/>
+				</svg>
+			</button>
+		</div>
+		<hr class="dark:border-gray-850" />
+
+		<div class="flex flex-col w-full px-5 py-4 dark:text-gray-200">
+			<div class="flex flex-col w-full">
+				{#each mergedDocuments as document}
+					<div class="flex flex-col w-full mb-2">
+						<div class="text-sm font-medium dark:text-gray-300">
+							{$i18n.t('Source')}
+						</div>
+						<div class="text-sm dark:text-gray-400">
+							{document.metadata.source}
+						</div>
+
+						<div class="text-sm font-medium dark:text-gray-300 mt-2">
+							{$i18n.t('Content')}
+						</div>
+						<pre class="text-sm dark:text-gray-400">
+							{document.document}
+						</pre>
+					</div>
+				{/each}
+			</div>
+		</div>
+	</div>
+</Modal>
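The `citation` prop this modal receives is exactly what `rag_messages` appends per retrieved context: the chunk texts from `context["documents"][0]` alongside the index-aligned `context["metadatas"][0]`. A type sketch of that shape; the interface names are illustrative only, since the patch itself types citations as `any`:

```ts
// Illustrative types; not exported anywhere in this patch.
interface CitationMetadata {
	source?: string; // origin recorded at ingestion time, e.g. a file name
	[key: string]: unknown;
}

interface Citation {
	document: string[]; // context["documents"][0]: the retrieved chunk texts
	metadata: CitationMetadata[]; // context["metadatas"][0], index-aligned
}

// The modal's mergedDocuments zips the two parallel arrays into one
// { document, metadata } row per retrieved chunk:
const mergeCitation = (citation: Citation) =>
	citation.document.map((doc, i) => ({
		document: doc,
		metadata: citation.metadata[i]
	}));
```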
diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte
index 4d87f929f..d7da104fe 100644
--- a/src/lib/components/chat/Messages/ResponseMessage.svelte
+++ b/src/lib/components/chat/Messages/ResponseMessage.svelte
@@ -32,6 +32,7 @@ import { WEBUI_BASE_URL } from '$lib/constants';

 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import RateComment from './RateComment.svelte';
+	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';

 	export let modelfiles = [];
 	export let message;
@@ -65,6 +66,8 @@

 	let showRateComment = false;

+	let showCitations = {};
+
 	$: tokens = marked.lexer(sanitizeResponseContent(message.content));

 	const renderer = new marked.Renderer();
@@ -360,6 +363,48 @@
 							{/each}
 						{/if}

+						{#if message.citations}
+							<div class="flex gap-1">
+								{#each message.citations as citation}
+									<div>
+										<CitationsModal bind:show={showCitations[citation]} {citation} />
+										<button
+											class="self-center text-xs dark:text-gray-400"
+											on:click={() => {
+												showCitations[citation] = !showCitations[citation];
+											}}
+										>
+											{$i18n.t('Citation')}
+										</button>
+									</div>
+								{/each}
+							</div>
+						{/if}
+
+
@@ -611,10 +657,11 @@
 								stroke-linejoin="round"
 								class="w-4 h-4"
 								xmlns="http://www.w3.org/2000/svg"
-							>
+
+
 						{/if}
@@ -637,35 +684,32 @@
 								fill="currentColor"
 								viewBox="0 0 24 24"
 								xmlns="http://www.w3.org/2000/svg"
-							>
+
+
+
+						{:else if speaking}
+
+
+
+						{:else}
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -366,7 +366,8 @@
 				},
 				format: $settings.requestFormat ?? undefined,
 				keep_alive: $settings.keepAlive ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			});

 			if (res && res.ok) {
@@ -401,6 +402,11 @@
 						console.log(line);
 						let data = JSON.parse(line);

+						if ('citations' in data) {
+							responseMessage.citations = data.citations;
+							continue;
+						}
+
 						if ('detail' in data) {
 							throw data;
 						}
@@ -598,7 +604,8 @@
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 				max_tokens: $settings?.options?.num_predict ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			},
 			model?.source?.toLowerCase() === 'litellm'
 				? `${LITELLM_API_BASE_URL}/v1`
@@ -614,7 +621,7 @@
 			const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);

 			for await (const update of textStream) {
-				const { value, done } = update;
+				const { value, done, citations } = update;
 				if (done || stopResponseFlag || _chatId !== $chatId) {
 					responseMessage.done = true;
 					messages = messages;
@@ -626,6 +633,11 @@
 					break;
 				}

+				if (citations) {
+					responseMessage.citations = citations;
+					continue;
+				}
+
 				if (responseMessage.content == '' && value == '\n') {
 					continue;
 				} else {
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index eab368a11..ccf85317e 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -378,7 +378,8 @@
 				},
 				format: $settings.requestFormat ?? undefined,
 				keep_alive: $settings.keepAlive ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			});

 			if (res && res.ok) {
@@ -413,6 +414,11 @@
 						console.log(line);
 						let data = JSON.parse(line);

+						if ('citations' in data) {
+							responseMessage.citations = data.citations;
+							continue;
+						}
+
 						if ('detail' in data) {
 							throw data;
 						}
@@ -610,7 +616,8 @@
 				num_ctx: $settings?.options?.num_ctx ?? undefined,
 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 				max_tokens: $settings?.options?.num_predict ?? undefined,
-				docs: docs.length > 0 ? docs : undefined
+				docs: docs.length > 0 ? docs : undefined,
+				citations: docs.length > 0
 			},
 			model?.source?.toLowerCase() === 'litellm'
 				? `${LITELLM_API_BASE_URL}/v1`
@@ -626,7 +633,7 @@
 			const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);

 			for await (const update of textStream) {
-				const { value, done } = update;
+				const { value, done, citations } = update;
 				if (done || stopResponseFlag || _chatId !== $chatId) {
 					responseMessage.done = true;
 					messages = messages;
@@ -638,6 +645,11 @@
 					break;
 				}

+				if (citations) {
+					responseMessage.citations = citations;
+					continue;
+				}
+
 				if (responseMessage.content == '' && value == '\n') {
 					continue;
 				} else {
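End to end, both route handlers opt in by sending `citations: true` next to `docs`; `RAGMiddleware` pops both keys off the body before proxying, so neither Ollama nor the OpenAI-compatible backend ever sees them. A sketch of the opt-in request follows; the endpoint path and the `docs` entry are illustrative (the middleware matches any POST whose path contains `/api/chat` or `/chat/completions`, and the doc shape comes from whatever is attached to the chat, not from this patch):

```ts
async function sendChatCompletion(token: string) {
	// Illustrative doc reference; real entries come from the attached files
	// or collections in the chat UI.
	const docs = [{ type: 'collection', collection_names: ['handbook'] }];

	return fetch('/openai/api/chat/completions', {
		method: 'POST',
		headers: {
			'Content-Type': 'application/json',
			Authorization: `Bearer ${token}`
		},
		body: JSON.stringify({
			model: 'gpt-3.5-turbo',
			stream: true,
			messages: [{ role: 'user', content: 'What does the handbook say about PTO?' }],
			docs: docs.length > 0 ? docs : undefined,
			citations: docs.length > 0 // stripped and honored by RAGMiddleware
		})
	});
}
```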