diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index b1142e855..8efa1a9f7 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -320,11 +320,19 @@ def rag_messages(
extracted_collections.extend(collection)
context_string = ""
+ citations = []
for context in relevant_contexts:
try:
if "documents" in context:
items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items)
+ if "metadatas" in context:
+ citations.append(
+ {
+ "document": context["documents"][0],
+ "metadata": context["metadatas"][0],
+ }
+ )
except Exception as e:
log.exception(e)
context_string = context_string.strip()
@@ -355,7 +363,7 @@ def rag_messages(
messages[last_user_message_idx] = new_user_message
- return messages
+ return messages, citations
def get_model_path(model: str, update_model: bool = False):
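
With this change, rag_messages returns a (messages, citations) tuple instead of bare
messages, so every call site has to unpack both values (the middleware below does).
Each entry in the citations list mirrors the Chroma-style query result it was built
from: a list of document chunks plus a parallel list of metadata dicts. A minimal
sketch of the shape a caller can expect; the field values are invented for
illustration and do not come from the codebase:

    # One entry of the citations list returned by rag_messages (hypothetical values).
    citation = {
        "document": [
            "First retrieved chunk of text...",
            "Second chunk from the same collection...",
        ],
        "metadata": [
            {"source": "docs/handbook.pdf", "page": 3},
            {"source": "docs/handbook.pdf", "page": 7},
        ],
    }

    # The lists are parallel: metadata[i] describes document[i].
    for text, meta in zip(citation["document"], citation["metadata"]):
        print(meta.get("source"), "->", text[:40])
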
diff --git a/backend/main.py b/backend/main.py
index e36b8296a..dc2175ad5 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
-
+from starlette.responses import StreamingResponse
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
@@ -102,6 +102,9 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
+        return_citations = False
+        citations = []
+
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
@@ -114,11 +117,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
+ return_citations = data.get("citations", False)
+ if "citations" in data:
+ del data["citations"]
+
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
data = {**data}
- data["messages"] = rag_messages(
+ data["messages"], citations = rag_messages(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
@@ -130,7 +137,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
)
del data["docs"]
- log.debug(f"data['messages']: {data['messages']}")
+ log.debug(
+ f"data['messages']: {data['messages']}, citations: {citations}"
+ )
modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -148,11 +157,38 @@ class RAGMiddleware(BaseHTTPMiddleware):
]
response = await call_next(request)
+
+        if return_citations:
+            # Inject the citations into the response
+            if isinstance(response, StreamingResponse):
+                # For streaming responses, inject the citations as an SSE event or NDJSON line
+                content_type = response.headers.get("Content-Type", "")
+                if "text/event-stream" in content_type:
+                    return StreamingResponse(
+                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        media_type="text/event-stream",
+                    )
+                if "application/x-ndjson" in content_type:
+                    return StreamingResponse(
+                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        media_type="application/x-ndjson",
+                    )
+
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
+ async def openai_stream_wrapper(self, original_generator, citations):
+ yield f"data: {json.dumps({'citations': citations})}\n\n"
+ async for data in original_generator:
+ yield data
+
+ async def ollama_stream_wrapper(self, original_generator, citations):
+ yield f"{json.dumps({'citations': citations})}\n"
+ async for data in original_generator:
+ yield data
+
app.add_middleware(RAGMiddleware)
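
The stream wrappers prepend the citations payload as one extra frame at the very
front of the stream: a data: event for OpenAI-style SSE responses, a bare JSON line
for Ollama's NDJSON. A self-contained sketch of that framing, assuming a stand-in
upstream generator with two invented delta events:

    import asyncio
    import json

    async def upstream():
        # Stand-in for the proxied SSE body; these deltas are made up.
        yield 'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n'
        yield 'data: {"choices": [{"delta": {"content": "lo"}}]}\n\n'

    async def wrapped(citations):
        # Citations go out first, as one extra SSE event...
        yield f"data: {json.dumps({'citations': citations})}\n\n"
        # ...then the original body is passed through untouched.
        async for chunk in upstream():
            yield chunk

    async def demo():
        async for chunk in wrapped([{"document": ["..."], "metadata": [{}]}]):
            print(chunk, end="")

    asyncio.run(demo())
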
diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts
index a72dbe47d..0e87c2524 100644
--- a/src/lib/apis/streaming/index.ts
+++ b/src/lib/apis/streaming/index.ts
@@ -4,6 +4,8 @@ import type { ParsedEvent } from 'eventsource-parser';
type TextStreamUpdate = {
done: boolean;
value: string;
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ citations?: any;
};
// createOpenAITextStream takes a responseBody with a SSE response,
@@ -45,6 +47,11 @@ async function* openAIStreamToIterator(
const parsedData = JSON.parse(data);
console.log(parsedData);
+ if (parsedData.citations) {
+ yield { done: false, value: '', citations: parsedData.citations };
+ continue;
+ }
+
yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
} catch (e) {
console.error('Error extracting delta from SSE event:', e);
@@ -62,6 +69,10 @@ async function* streamLargeDeltasAsRandomChunks(
yield textStreamUpdate;
return;
}
+ if (textStreamUpdate.citations) {
+ yield textStreamUpdate;
+ continue;
+ }
let content = textStreamUpdate.value;
if (content.length < 5) {
yield { done: false, value: content };
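
On the client side, any parsed event that carries a citations key is treated as
metadata rather than text: the iterator yields it with an empty value, and the
chunking pass above forwards such updates without splitting them. The same
contract, sketched in Python for clarity (the event strings are hypothetical):

    import json

    def parse_sse_data(line: str):
        # Mirrors the TypeScript branch above: a citations event yields
        # metadata with empty text; anything else yields delta content.
        data = line.removeprefix("data: ").strip()
        if data == "[DONE]":
            return {"done": True, "value": ""}
        parsed = json.loads(data)
        if "citations" in parsed:
            return {"done": False, "value": "", "citations": parsed["citations"]}
        delta = (parsed.get("choices") or [{}])[0].get("delta", {})
        return {"done": False, "value": delta.get("content") or ""}

    print(parse_sse_data('data: {"citations": [{"document": ["..."]}]}'))
    print(parse_sse_data('data: {"choices": [{"delta": {"content": "Hi"}}]}'))
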
diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte
new file mode 100644
index 000000000..873fab283
--- /dev/null
+++ b/src/lib/components/chat/Messages/CitationsModal.svelte
@@ -0,0 +1,34 @@
+<script lang="ts">
+	import Modal from '$lib/components/common/Modal.svelte';
+
+	export let show = false;
+	export let citation;
+
+	let mergedDocuments = [];
+
+	// Pair each retrieved chunk with its metadata entry; the two
+	// arrays arrive parallel in the citations payload.
+	$: mergedDocuments =
+		citation?.document?.map((document, idx) => ({
+			source: citation.metadata?.[idx],
+			document
+		})) ?? [];
+</script>
+
+<Modal bind:show>
+	<div class="px-5 py-4 dark:text-gray-300">
+		<div class="flex justify-between items-center">
+			<div class="text-lg font-medium">Citation</div>
+			<button class="self-center" on:click={() => (show = false)}>Close</button>
+		</div>
+
+		{#each mergedDocuments as document}
+			<div class="mt-3">
+				<div class="text-sm font-medium">{document.source?.source ?? ''}</div>
+				<div class="text-sm dark:text-gray-400 whitespace-pre-line">
+					{document.document}
+				</div>
+			</div>
+		{/each}
+	</div>
+</Modal>