Merge branch 'dev' into feat/backend-web-search

This commit is contained in:
Timothy Jaeryang Baek
2024-05-06 16:26:44 -07:00
committed by GitHub
56 changed files with 2166 additions and 1071 deletions

View File

@@ -80,6 +80,7 @@ from config import (
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
RAG_RERANKING_MODEL,
PDF_EXTRACT_IMAGES,
RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -91,7 +92,7 @@ from config import (
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
)
from constants import ERROR_MESSAGES
@@ -105,6 +106,9 @@ app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -114,6 +118,7 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
@@ -313,6 +318,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
}
@@ -322,15 +328,34 @@ class ChunkParamUpdateForm(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: bool
chunk: ChunkParamUpdateForm
pdf_extract_images: Optional[bool] = None
chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
)
app.state.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
return {
"status": True,
@@ -339,6 +364,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
}
@@ -490,7 +516,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(form_data.url)
loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
data = loader.load()
collection_name = form_data.collection_name
@@ -510,12 +538,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
detail=ERROR_MESSAGES.DEFAULT(e),
)
def get_web_loader(url: Union[str, Sequence[str]]):
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
return WebBaseLoader(url, verify_ssl=verify_ssl)
def validate_url(url: Union[str, Sequence[str]]):

View File

@@ -287,14 +287,14 @@ def rag_messages(
for doc in docs:
context = None
collection = doc.get("collection_name")
if collection:
collection = [collection]
else:
collection = doc.get("collection_names", [])
collection_names = (
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
)
collection = set(collection).difference(extracted_collections)
if not collection:
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted")
continue
@@ -304,11 +304,7 @@ def rag_messages(
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
@@ -317,11 +313,7 @@ def rag_messages(
)
else:
context = query_collection(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
@@ -331,18 +323,31 @@ def rag_messages(
context = None
if context:
relevant_contexts.append(context)
relevant_contexts.append({**context, "source": doc})
extracted_collections.extend(collection)
extracted_collections.extend(collection_names)
context_string = ""
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items)
context_string += "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
if "metadatas" in context:
citations.append(
{
"source": context["source"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
)
except Exception as e:
log.exception(e)
context_string = context_string.strip()
ra_content = rag_template(
@@ -371,7 +376,7 @@ def rag_messages(
messages[last_user_message_idx] = new_user_message
return messages
return messages, citations
def get_model_path(model: str, update_model: bool = False):

View File

@@ -18,6 +18,18 @@ from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env"))
except ImportError:
print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################
@@ -59,16 +71,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
# Load .env file
####################################
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env"))
except ImportError:
log.warning("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
@@ -454,6 +456,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
@@ -531,7 +538,9 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")

View File

@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
@@ -102,6 +102,8 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
data = {**data}
data["messages"] = rag_messages(
data["messages"], citations = rag_messages(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
)
del data["docs"]
log.debug(f"data['messages']: {data['messages']}")
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
)
modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
]
response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
yield f"data: {json.dumps({'citations': citations})}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, citations):
yield f"{json.dumps({'citations': citations})}\n"
async for data in original_generator:
yield data
app.add_middleware(RAGMiddleware)