mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: citations -> sources
This commit is contained in:
@@ -902,10 +902,11 @@ def process_file(
|
||||
Document(
|
||||
page_content=form_data.content,
|
||||
metadata={
|
||||
"name": file.meta.get("name", file.filename),
|
||||
**file.meta,
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
"file_id": file.id,
|
||||
**file.meta,
|
||||
"source": file.filename,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -932,10 +933,11 @@ def process_file(
|
||||
Document(
|
||||
page_content=file.data.get("content", ""),
|
||||
metadata={
|
||||
"name": file.meta.get("name", file.filename),
|
||||
**file.meta,
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
"file_id": file.id,
|
||||
**file.meta,
|
||||
"source": file.filename,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -955,15 +957,30 @@ def process_file(
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
)
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
page_content=doc.page_content,
|
||||
metadata={
|
||||
**doc.metadata,
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
"file_id": file.id,
|
||||
"source": file.filename,
|
||||
},
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
else:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=file.data.get("content", ""),
|
||||
metadata={
|
||||
**file.meta,
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
"file_id": file.id,
|
||||
**file.meta,
|
||||
"source": file.filename,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -307,7 +307,7 @@ def get_embedding_function(
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
|
||||
def get_rag_context(
|
||||
def get_sources_from_files(
|
||||
files,
|
||||
queries,
|
||||
embedding_function,
|
||||
@@ -387,43 +387,24 @@ def get_rag_context(
|
||||
del file["data"]
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
contexts = []
|
||||
citations = []
|
||||
sources = []
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
file_names = list(
|
||||
set(
|
||||
[
|
||||
metadata["name"]
|
||||
for metadata in context["metadatas"][0]
|
||||
if metadata is not None and "name" in metadata
|
||||
]
|
||||
)
|
||||
)
|
||||
contexts.append(
|
||||
((", ".join(file_names) + ":\n\n") if file_names else "")
|
||||
+ "\n\n".join(
|
||||
[text for text in context["documents"][0] if text is not None]
|
||||
)
|
||||
)
|
||||
|
||||
if "metadatas" in context:
|
||||
citation = {
|
||||
source = {
|
||||
"source": context["file"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
if "distances" in context and context["distances"]:
|
||||
citation["distances"] = context["distances"][0]
|
||||
citations.append(citation)
|
||||
source["distances"] = context["distances"][0]
|
||||
|
||||
sources.append(source)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
print("contexts", contexts)
|
||||
print("citations", citations)
|
||||
|
||||
return contexts, citations
|
||||
return sources
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
|
||||
@@ -56,7 +56,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
||||
FileForm(
|
||||
**{
|
||||
"id": id,
|
||||
"filename": filename,
|
||||
"filename": name,
|
||||
"path": file_path,
|
||||
"meta": {
|
||||
"name": name,
|
||||
|
||||
@@ -49,7 +49,7 @@ from open_webui.apps.openai.main import (
|
||||
get_all_models_responses as get_openai_models_responses,
|
||||
)
|
||||
from open_webui.apps.retrieval.main import app as retrieval_app
|
||||
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
||||
from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template
|
||||
from open_webui.apps.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
@@ -380,8 +380,7 @@ async def chat_completion_tools_handler(
|
||||
return body, {}
|
||||
|
||||
skip_files = False
|
||||
contexts = []
|
||||
citations = []
|
||||
sources = []
|
||||
|
||||
task_model_id = get_task_model_id(
|
||||
body["model"],
|
||||
@@ -465,24 +464,37 @@ async def chat_completion_tools_handler(
|
||||
|
||||
print(tools[tool_function_name]["citation"])
|
||||
|
||||
if tools[tool_function_name]["citation"]:
|
||||
citations.append(
|
||||
{
|
||||
"source": {
|
||||
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
|
||||
},
|
||||
"document": [tool_output],
|
||||
"metadata": [{"source": tool_function_name}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
citations.append({})
|
||||
|
||||
if tools[tool_function_name]["file_handler"]:
|
||||
skip_files = True
|
||||
|
||||
if isinstance(tool_output, str):
|
||||
contexts.append(tool_output)
|
||||
if tools[tool_function_name]["citation"]:
|
||||
sources.append(
|
||||
{
|
||||
"source": {
|
||||
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
|
||||
},
|
||||
"document": [tool_output],
|
||||
"metadata": [
|
||||
{
|
||||
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
sources.append(
|
||||
{
|
||||
"source": {},
|
||||
"document": [tool_output],
|
||||
"metadata": [
|
||||
{
|
||||
"source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tools[tool_function_name]["file_handler"]:
|
||||
skip_files = True
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error: {e}")
|
||||
content = None
|
||||
@@ -490,19 +502,18 @@ async def chat_completion_tools_handler(
|
||||
log.exception(f"Error: {e}")
|
||||
content = None
|
||||
|
||||
log.debug(f"tool_contexts: {contexts} {citations}")
|
||||
log.debug(f"tool_contexts: {sources}")
|
||||
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {"contexts": contexts, "citations": citations}
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_completion_files_handler(
|
||||
body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
contexts = []
|
||||
citations = []
|
||||
sources = []
|
||||
|
||||
try:
|
||||
queries_response = await generate_queries(
|
||||
@@ -530,7 +541,7 @@ async def chat_completion_files_handler(
|
||||
print(f"{queries=}")
|
||||
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
contexts, citations = get_rag_context(
|
||||
sources = get_sources_from_files(
|
||||
files=files,
|
||||
queries=queries,
|
||||
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
|
||||
@@ -540,9 +551,8 @@ async def chat_completion_files_handler(
|
||||
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
|
||||
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
||||
|
||||
return body, {"contexts": contexts, "citations": citations}
|
||||
log.debug(f"rag_contexts:sources: {sources}")
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
def is_chat_completion_request(request):
|
||||
@@ -643,8 +653,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
# Initialize data_items to store additional data to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
data_items = []
|
||||
contexts = []
|
||||
citations = []
|
||||
sources = []
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_filter_functions_handler(
|
||||
@@ -670,32 +679,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
body, flags = await chat_completion_tools_handler(
|
||||
body, user, models, extra_params
|
||||
)
|
||||
contexts.extend(flags.get("contexts", []))
|
||||
citations.extend(flags.get("citations", []))
|
||||
sources.extend(flags.get("sources", []))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_files_handler(body, user)
|
||||
contexts.extend(flags.get("contexts", []))
|
||||
citations.extend(flags.get("citations", []))
|
||||
sources.extend(flags.get("sources", []))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
# If context is not empty, insert it into the messages
|
||||
if len(contexts) > 0:
|
||||
if len(sources) > 0:
|
||||
context_string = ""
|
||||
for context_idx, context in enumerate(contexts):
|
||||
print(context)
|
||||
source_id = citations[context_idx].get("source", {}).get("name", "")
|
||||
for source_idx, source in enumerate(sources):
|
||||
source_id = source.get("source", {}).get("name", "")
|
||||
|
||||
print(f"\n\n\n\n{source_id}\n\n\n\n")
|
||||
if source_id:
|
||||
context_string += f"<source><source_id>{source_id}</source_id><source_context>{context}</source_context></source>\n"
|
||||
else:
|
||||
context_string += (
|
||||
f"<source><source_context>{context}</source_context></source>\n"
|
||||
)
|
||||
if "document" in source:
|
||||
for doc_idx, doc_context in enumerate(source["document"]):
|
||||
metadata = source.get("metadata")
|
||||
|
||||
if metadata:
|
||||
doc_source_id = metadata[doc_idx].get("source", source_id)
|
||||
|
||||
if source_id:
|
||||
context_string += f"<source><source_id>{doc_source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||
else:
|
||||
# If there is no source_id, then do not include the source_id tag
|
||||
context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
|
||||
|
||||
context_string = context_string.strip()
|
||||
prompt = get_last_user_message(body["messages"])
|
||||
@@ -728,8 +739,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
# If there are citations, add them to the data_items
|
||||
if len(citations) > 0:
|
||||
data_items.append({"citations": citations})
|
||||
sources = [
|
||||
source for source in sources if source.get("source", {}).get("name", "")
|
||||
]
|
||||
if len(sources) > 0:
|
||||
data_items.append({"sources": sources})
|
||||
|
||||
modified_body_bytes = json.dumps(body).encode("utf-8")
|
||||
# Replace the request body with the modified one
|
||||
|
||||
Reference in New Issue
Block a user