diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 10d6ff9a7..ba53a1c89 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -902,10 +902,11 @@ def process_file( Document( page_content=form_data.content, metadata={ - "name": file.meta.get("name", file.filename), + **file.meta, + "name": file.filename, "created_by": file.user_id, "file_id": file.id, - **file.meta, + "source": file.filename, }, ) ] @@ -932,10 +933,11 @@ def process_file( Document( page_content=file.data.get("content", ""), metadata={ - "name": file.meta.get("name", file.filename), + **file.meta, + "name": file.filename, "created_by": file.user_id, "file_id": file.id, - **file.meta, + "source": file.filename, }, ) ] @@ -955,15 +957,30 @@ def process_file( docs = loader.load( file.filename, file.meta.get("content_type"), file_path ) + + docs = [ + Document( + page_content=doc.page_content, + metadata={ + **doc.metadata, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + for doc in docs + ] else: docs = [ Document( page_content=file.data.get("content", ""), metadata={ + **file.meta, "name": file.filename, "created_by": file.user_id, "file_id": file.id, - **file.meta, + "source": file.filename, }, ) ] diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 6d87c98e3..e4e36fbfd 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -307,7 +307,7 @@ def get_embedding_function( return lambda query: generate_multiple(query, func) -def get_rag_context( +def get_sources_from_files( files, queries, embedding_function, @@ -387,43 +387,24 @@ def get_rag_context( del file["data"] relevant_contexts.append({**context, "file": file}) - contexts = [] - citations = [] + sources = [] for context in relevant_contexts: try: if "documents" in context: - file_names = list( - set( - [ - metadata["name"] - for metadata in context["metadatas"][0] - if metadata is not None and "name" in metadata - ] - ) - ) - contexts.append( - ((", ".join(file_names) + ":\n\n") if file_names else "") - + "\n\n".join( - [text for text in context["documents"][0] if text is not None] - ) - ) - if "metadatas" in context: - citation = { + source = { "source": context["file"], "document": context["documents"][0], "metadata": context["metadatas"][0], } if "distances" in context and context["distances"]: - citation["distances"] = context["distances"][0] - citations.append(citation) + source["distances"] = context["distances"][0] + + sources.append(source) except Exception as e: log.exception(e) - print("contexts", contexts) - print("citations", citations) - - return contexts, citations + return sources def get_model_path(model: str, update_model: bool = False): diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index b8695eb67..e7459a15f 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -56,7 +56,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): FileForm( **{ "id": id, - "filename": filename, + "filename": name, "path": file_path, "meta": { "name": name, diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d62ef158e..797c9622a 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -49,7 +49,7 @@ from open_webui.apps.openai.main import ( get_all_models_responses as get_openai_models_responses, ) from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_rag_context, rag_template +from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template from open_webui.apps.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, @@ -380,8 +380,7 @@ async def chat_completion_tools_handler( return body, {} skip_files = False - contexts = [] - citations = [] + sources = [] task_model_id = get_task_model_id( body["model"], @@ -465,24 +464,37 @@ async def chat_completion_tools_handler( print(tools[tool_function_name]["citation"]) - if tools[tool_function_name]["citation"]: - citations.append( - { - "source": { - "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" - }, - "document": [tool_output], - "metadata": [{"source": tool_function_name}], - } - ) - else: - citations.append({}) - - if tools[tool_function_name]["file_handler"]: - skip_files = True - if isinstance(tool_output, str): - contexts.append(tool_output) + if tools[tool_function_name]["citation"]: + sources.append( + { + "source": { + "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + }, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + else: + sources.append( + { + "source": {}, + "document": [tool_output], + "metadata": [ + { + "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}" + } + ], + } + ) + + if tools[tool_function_name]["file_handler"]: + skip_files = True + except Exception as e: log.exception(f"Error: {e}") content = None @@ -490,19 +502,18 @@ async def chat_completion_tools_handler( log.exception(f"Error: {e}") content = None - log.debug(f"tool_contexts: {contexts} {citations}") + log.debug(f"tool_contexts: {sources}") if skip_files and "files" in body.get("metadata", {}): del body["metadata"]["files"] - return body, {"contexts": contexts, "citations": citations} + return body, {"sources": sources} async def chat_completion_files_handler( body: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: - contexts = [] - citations = [] + sources = [] try: queries_response = await generate_queries( @@ -530,7 +541,7 @@ async def chat_completion_files_handler( print(f"{queries=}") if files := body.get("metadata", {}).get("files", None): - contexts, citations = get_rag_context( + sources = get_sources_from_files( files=files, queries=queries, embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, @@ -540,9 +551,8 @@ async def chat_completion_files_handler( hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) - log.debug(f"rag_contexts: {contexts}, citations: {citations}") - - return body, {"contexts": contexts, "citations": citations} + log.debug(f"rag_contexts:sources: {sources}") + return body, {"sources": sources} def is_chat_completion_request(request): @@ -643,8 +653,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): # Initialize data_items to store additional data to be sent to the client # Initialize contexts and citation data_items = [] - contexts = [] - citations = [] + sources = [] try: body, flags = await chat_completion_filter_functions_handler( @@ -670,32 +679,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): body, flags = await chat_completion_tools_handler( body, user, models, extra_params ) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) + sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) try: body, flags = await chat_completion_files_handler(body, user) - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) + sources.extend(flags.get("sources", [])) except Exception as e: log.exception(e) # If context is not empty, insert it into the messages - if len(contexts) > 0: + if len(sources) > 0: context_string = "" - for context_idx, context in enumerate(contexts): - print(context) - source_id = citations[context_idx].get("source", {}).get("name", "") + for source_idx, source in enumerate(sources): + source_id = source.get("source", {}).get("name", "") - print(f"\n\n\n\n{source_id}\n\n\n\n") - if source_id: - context_string += f"\n" - else: - context_string += ( - f"\n" - ) + if "document" in source: + for doc_idx, doc_context in enumerate(source["document"]): + metadata = source.get("metadata") + + if metadata: + doc_source_id = metadata[doc_idx].get("source", source_id) + + if source_id: + context_string += f"\n" + else: + # If there is no source_id, then do not include the source_id tag + context_string += f"\n" context_string = context_string.strip() prompt = get_last_user_message(body["messages"]) @@ -728,8 +739,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) # If there are citations, add them to the data_items - if len(citations) > 0: - data_items.append({"citations": citations}) + sources = [ + source for source in sources if source.get("source", {}).get("name", "") + ] + if len(sources) > 0: + data_items.append({"sources": sources}) modified_body_bytes = json.dumps(body).encode("utf-8") # Replace the request body with the modified one diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts index a8249abe0..54804385d 100644 --- a/src/lib/apis/streaming/index.ts +++ b/src/lib/apis/streaming/index.ts @@ -5,7 +5,7 @@ type TextStreamUpdate = { done: boolean; value: string; // eslint-disable-next-line @typescript-eslint/no-explicit-any - citations?: any; + sources?: any; // eslint-disable-next-line @typescript-eslint/no-explicit-any selectedModelId?: any; error?: any; @@ -67,8 +67,8 @@ async function* openAIStreamToIterator( break; } - if (parsedData.citations) { - yield { done: false, value: '', citations: parsedData.citations }; + if (parsedData.sources) { + yield { done: false, value: '', sources: parsedData.sources }; continue; } @@ -98,7 +98,7 @@ async function* streamLargeDeltasAsRandomChunks( yield textStreamUpdate; return; } - if (textStreamUpdate.citations) { + if (textStreamUpdate.sources) { yield textStreamUpdate; continue; } diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 8a1ef2d91..ae584ba8f 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -236,10 +236,10 @@ message.code_executions = message.code_executions; } else { // Regular citation. - if (message?.citations) { - message.citations.push(data); + if (message?.sources) { + message.sources.push(data); } else { - message.citations = [data]; + message.sources = [data]; } } } else if (type === 'message') { @@ -664,7 +664,7 @@ content: m.content, info: m.info ? m.info : undefined, timestamp: m.timestamp, - ...(m.citations ? { citations: m.citations } : {}) + ...(m.sources ? { sources: m.sources } : {}) })), chat_id: chatId, session_id: $socket?.id, @@ -718,7 +718,7 @@ content: m.content, info: m.info ? m.info : undefined, timestamp: m.timestamp, - ...(m.citations ? { citations: m.citations } : {}) + ...(m.sources ? { sources: m.sources } : {}) })), ...(event ? { event: event } : {}), chat_id: chatId, @@ -1278,8 +1278,8 @@ console.log(line); let data = JSON.parse(line); - if ('citations' in data) { - responseMessage.citations = data.citations; + if ('sources' in data) { + responseMessage.sources = data.sources; // Only remove status if it was initially set if (model?.info?.meta?.knowledge ?? false) { responseMessage.statusHistory = responseMessage.statusHistory.filter( @@ -1632,7 +1632,7 @@ const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); for await (const update of textStream) { - const { value, done, citations, selectedModelId, error, usage } = update; + const { value, done, sources, selectedModelId, error, usage } = update; if (error) { await handleOpenAIError(error, null, model, responseMessage); break; @@ -1658,8 +1658,8 @@ continue; } - if (citations) { - responseMessage.citations = citations; + if (sources) { + responseMessage.sources = sources; // Only remove status if it was initially set if (model?.info?.meta?.knowledge ?? false) { responseMessage.statusHistory = responseMessage.statusHistory.filter( @@ -1938,7 +1938,7 @@ if (res && res.ok && res.body) { const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); for await (const update of textStream) { - const { value, done, citations, error, usage } = update; + const { value, done, sources, error, usage } = update; if (error || done) { break; } diff --git a/src/lib/components/chat/Messages/Citations.svelte b/src/lib/components/chat/Messages/Citations.svelte index 1e24518cd..57f7b1ab9 100644 --- a/src/lib/components/chat/Messages/Citations.svelte +++ b/src/lib/components/chat/Messages/Citations.svelte @@ -7,9 +7,9 @@ const i18n = getContext('i18n'); - export let citations = []; + export let sources = []; - let _citations = []; + let citations = []; let showPercentage = false; let showRelevance = true; @@ -17,8 +17,8 @@ let selectedCitation: any = null; let isCollapsibleOpen = false; - function calculateShowRelevance(citations: any[]) { - const distances = citations.flatMap((citation) => citation.distances ?? []); + function calculateShowRelevance(sources: any[]) { + const distances = sources.flatMap((citation) => citation.distances ?? []); const inRange = distances.filter((d) => d !== undefined && d >= -1 && d <= 1).length; const outOfRange = distances.filter((d) => d !== undefined && (d < -1 || d > 1)).length; @@ -36,29 +36,31 @@ return true; } - function shouldShowPercentage(citations: any[]) { - const distances = citations.flatMap((citation) => citation.distances ?? []); + function shouldShowPercentage(sources: any[]) { + const distances = sources.flatMap((citation) => citation.distances ?? []); return distances.every((d) => d !== undefined && d >= -1 && d <= 1); } $: { - _citations = citations.reduce((acc, citation) => { - if (Object.keys(citation).length === 0) { + citations = sources.reduce((acc, source) => { + if (Object.keys(source).length === 0) { return acc; } - citation.document.forEach((document, index) => { - const metadata = citation.metadata?.[index]; - const distance = citation.distances?.[index]; + source.document.forEach((document, index) => { + const metadata = source.metadata?.[index]; + const distance = source.distances?.[index]; + + // Within the same citation there could be multiple documents const id = metadata?.source ?? 'N/A'; - let source = citation?.source; + let _source = source?.source; if (metadata?.name) { - source = { ...source, name: metadata.name }; + _source = { ..._source, name: metadata.name }; } if (id.startsWith('http://') || id.startsWith('https://')) { - source = { ...source, name: id, url: id }; + _source = { ..._source, name: id, url: id }; } const existingSource = acc.find((item) => item.id === id); @@ -70,7 +72,7 @@ } else { acc.push({ id: id, - source: source, + source: _source, document: [document], metadata: metadata ? [metadata] : [], distances: distance !== undefined ? [distance] : undefined @@ -80,8 +82,8 @@ return acc; }, []); - showRelevance = calculateShowRelevance(_citations); - showPercentage = shouldShowPercentage(_citations); + showRelevance = calculateShowRelevance(citations); + showPercentage = shouldShowPercentage(citations); } @@ -92,11 +94,11 @@ {showRelevance} /> -{#if _citations.length > 0} +{#if citations.length > 0}