mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into feat/backend-web-search
This commit is contained in:
@@ -80,6 +80,7 @@ from config import (
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
RAG_RERANKING_MODEL,
|
||||
PDF_EXTRACT_IMAGES,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
@@ -91,7 +92,7 @@ from config import (
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
RAG_TEMPLATE,
|
||||
ENABLE_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
)
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
@@ -105,6 +106,9 @@ app.state.TOP_K = RAG_TOP_K
|
||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
|
||||
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
@@ -114,6 +118,7 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
|
||||
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
@@ -313,6 +318,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
|
||||
"chunk_size": app.state.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
||||
},
|
||||
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
}
|
||||
|
||||
|
||||
@@ -322,15 +328,34 @@ class ChunkParamUpdateForm(BaseModel):
|
||||
|
||||
|
||||
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``; every field is optional.

    ``update_rag_config`` keeps the current ``app.state`` value for any field
    that is left as ``None``, so a client may send only the settings it wants
    to change.
    """

    # Replaces app.state.PDF_EXTRACT_IMAGES when provided.
    pdf_extract_images: Optional[bool] = None
    # Supplies chunk_size / chunk_overlap for app.state when provided.
    chunk: Optional[ChunkParamUpdateForm] = None
    # Replaces app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION when provided.
    web_loader_ssl_verification: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
||||
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
|
||||
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
|
||||
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
|
||||
app.state.PDF_EXTRACT_IMAGES = (
|
||||
form_data.pdf_extract_images
|
||||
if form_data.pdf_extract_images != None
|
||||
else app.state.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
|
||||
app.state.CHUNK_SIZE = (
|
||||
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
|
||||
)
|
||||
|
||||
app.state.CHUNK_OVERLAP = (
|
||||
form_data.chunk.chunk_overlap
|
||||
if form_data.chunk != None
|
||||
else app.state.CHUNK_OVERLAP
|
||||
)
|
||||
|
||||
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
form_data.web_loader_ssl_verification
|
||||
if form_data.web_loader_ssl_verification != None
|
||||
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@@ -339,6 +364,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
||||
"chunk_size": app.state.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
||||
},
|
||||
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
}
|
||||
|
||||
|
||||
@@ -490,7 +516,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
try:
|
||||
loader = get_web_loader(form_data.url)
|
||||
loader = get_web_loader(
|
||||
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
data = loader.load()
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
@@ -510,12 +538,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Return a ``WebBaseLoader`` for ``url`` after validating it.

    Args:
        url: A single URL or a sequence of URLs to load.
        verify_ssl: Forwarded to ``WebBaseLoader`` as ``verify_ssl``. Defaults
            to ``True`` so callers that do not pass it keep certificate
            verification enabled.

    Raises:
        ValueError: With ``ERROR_MESSAGES.INVALID_URL`` when
            ``validate_url(url)`` returns a falsy value.
    """
    # Check if the URL is valid
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    # Single return path: the SSL setting must always reach the loader.
    return WebBaseLoader(url, verify_ssl=verify_ssl)
|
||||
|
||||
|
||||
def validate_url(url: Union[str, Sequence[str]]):
|
||||
|
||||
@@ -287,14 +287,14 @@ def rag_messages(
|
||||
for doc in docs:
|
||||
context = None
|
||||
|
||||
collection = doc.get("collection_name")
|
||||
if collection:
|
||||
collection = [collection]
|
||||
else:
|
||||
collection = doc.get("collection_names", [])
|
||||
collection_names = (
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
)
|
||||
|
||||
collection = set(collection).difference(extracted_collections)
|
||||
if not collection:
|
||||
collection_names = set(collection_names).difference(extracted_collections)
|
||||
if not collection_names:
|
||||
log.debug(f"skipping {doc} as it has already been extracted")
|
||||
continue
|
||||
|
||||
@@ -304,11 +304,7 @@ def rag_messages(
|
||||
else:
|
||||
if hybrid_search:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=(
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
),
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
@@ -317,11 +313,7 @@ def rag_messages(
|
||||
)
|
||||
else:
|
||||
context = query_collection(
|
||||
collection_names=(
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
),
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
@@ -331,18 +323,31 @@ def rag_messages(
|
||||
context = None
|
||||
|
||||
if context:
|
||||
relevant_contexts.append(context)
|
||||
relevant_contexts.append({**context, "source": doc})
|
||||
|
||||
extracted_collections.extend(collection)
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
context_string = ""
|
||||
|
||||
citations = []
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
items = [item for item in context["documents"][0] if item is not None]
|
||||
context_string += "\n\n".join(items)
|
||||
context_string += "\n\n".join(
|
||||
[text for text in context["documents"][0] if text is not None]
|
||||
)
|
||||
|
||||
if "metadatas" in context:
|
||||
citations.append(
|
||||
{
|
||||
"source": context["source"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
context_string = context_string.strip()
|
||||
|
||||
ra_content = rag_template(
|
||||
@@ -371,7 +376,7 @@ def rag_messages(
|
||||
|
||||
messages[last_user_message_idx] = new_user_message
|
||||
|
||||
return messages
|
||||
return messages, citations
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
|
||||
@@ -18,6 +18,18 @@ from secrets import token_bytes
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
####################################
|
||||
# Load .env file
|
||||
####################################
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
load_dotenv(find_dotenv("../.env"))
|
||||
except ImportError:
|
||||
print("dotenv not installed, skipping...")
|
||||
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
####################################
|
||||
@@ -59,16 +71,6 @@ for source in log_sources:
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
####################################
|
||||
# Load .env file
|
||||
####################################
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
load_dotenv(find_dotenv("../.env"))
|
||||
except ImportError:
|
||||
log.warning("dotenv not installed, skipping...")
|
||||
|
||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
@@ -454,6 +456,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
|
||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
|
||||
|
||||
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
|
||||
@@ -531,7 +538,9 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
|
||||
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
||||
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
|
||||
|
||||
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH = (
|
||||
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||
)
|
||||
|
||||
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
|
||||
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from apps.ollama.main import app as ollama_app
|
||||
from apps.openai.main import app as openai_app
|
||||
@@ -102,6 +102,8 @@ origins = ["*"]
|
||||
|
||||
class RAGMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
):
|
||||
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
return_citations = data.get("citations", False)
|
||||
if "citations" in data:
|
||||
del data["citations"]
|
||||
|
||||
# Example: Add a new key-value pair or modify existing ones
|
||||
# data["modified"] = True # Example modification
|
||||
if "docs" in data:
|
||||
data = {**data}
|
||||
data["messages"] = rag_messages(
|
||||
data["messages"], citations = rag_messages(
|
||||
docs=data["docs"],
|
||||
messages=data["messages"],
|
||||
template=rag_app.state.RAG_TEMPLATE,
|
||||
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
del data["docs"]
|
||||
|
||||
log.debug(f"data['messages']: {data['messages']}")
|
||||
log.debug(
|
||||
f"data['messages']: {data['messages']}, citations: {citations}"
|
||||
)
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
|
||||
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if return_citations:
|
||||
# Inject the citations into the response
|
||||
if isinstance(response, StreamingResponse):
|
||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if "text/event-stream" in content_type:
|
||||
return StreamingResponse(
|
||||
self.openai_stream_wrapper(response.body_iterator, citations),
|
||||
)
|
||||
if "application/x-ndjson" in content_type:
|
||||
return StreamingResponse(
|
||||
self.ollama_stream_wrapper(response.body_iterator, citations),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _receive(self, body: bytes):
    """Return an ASGI ``http.request`` message that carries ``body``.

    ``more_body`` is always ``False``, i.e. the message presents ``body`` as
    the complete request payload.
    """
    return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def openai_stream_wrapper(self, original_generator, citations):
    """Prefix a ``text/event-stream`` body with one citations event.

    Yields a single ``data: {...}`` line (terminated by a blank line) whose
    JSON payload is ``{"citations": citations}``, then relays every chunk of
    ``original_generator`` unchanged and in order.

    Args:
        original_generator: Async iterable producing the upstream body chunks.
        citations: JSON-serialisable citation list to inject.
    """
    citation_payload = json.dumps({"citations": citations})
    # Emitted before any upstream chunk so the client receives it first.
    yield f"data: {citation_payload}\n\n"
    async for chunk in original_generator:
        yield chunk
|
||||
|
||||
async def ollama_stream_wrapper(self, original_generator, citations):
    """Prefix an ``application/x-ndjson`` body with one citations line.

    Yields a single newline-terminated JSON document
    ``{"citations": citations}``, then relays every chunk of
    ``original_generator`` unchanged and in order.

    Args:
        original_generator: Async iterable producing the upstream body chunks.
        citations: JSON-serialisable citation list to inject.
    """
    citation_payload = json.dumps({"citations": citations})
    # Emitted before any upstream chunk so the client receives it first.
    yield f"{citation_payload}\n"
    async for chunk in original_generator:
        yield chunk
|
||||
|
||||
|
||||
app.add_middleware(RAGMiddleware)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user