mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'upstream-dev' into dev
This commit is contained in:
@@ -15,6 +15,9 @@ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, sta
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.apps.webui.models.knowledge import Knowledges
|
||||
|
||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
|
||||
# Document loaders
|
||||
@@ -47,6 +50,8 @@ from open_webui.apps.retrieval.utils import (
|
||||
from open_webui.apps.webui.models.files import Files
|
||||
from open_webui.config import (
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
RAG_TEXT_SPLITTER,
|
||||
CHUNK_OVERLAP,
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
@@ -102,7 +107,7 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
|
||||
from langchain_community.document_loaders import (
|
||||
YoutubeLoader,
|
||||
)
|
||||
@@ -129,6 +134,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
|
||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||
|
||||
app.state.config.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
@@ -171,9 +179,9 @@ def update_embedding_model(
|
||||
auto_update: bool = False,
|
||||
):
|
||||
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
||||
import sentence_transformers
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||
app.state.sentence_transformer_ef = SentenceTransformer(
|
||||
get_model_path(embedding_model, auto_update),
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
@@ -384,18 +392,19 @@ async def get_rag_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"file": {
|
||||
"max_size": app.state.config.FILE_MAX_SIZE,
|
||||
"max_count": app.state.config.FILE_MAX_COUNT,
|
||||
},
|
||||
"content_extraction": {
|
||||
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": app.state.config.TIKA_SERVER_URL,
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": app.state.config.TEXT_SPLITTER,
|
||||
"chunk_size": app.state.config.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
|
||||
},
|
||||
"file": {
|
||||
"max_size": app.state.config.FILE_MAX_SIZE,
|
||||
"max_count": app.state.config.FILE_MAX_COUNT,
|
||||
},
|
||||
"youtube": {
|
||||
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
||||
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
@@ -434,6 +443,7 @@ class ContentExtractionConfig(BaseModel):
|
||||
|
||||
|
||||
class ChunkParamUpdateForm(BaseModel):
|
||||
text_splitter: Optional[str] = None
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
|
||||
@@ -493,6 +503,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
||||
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
|
||||
|
||||
if form_data.chunk is not None:
|
||||
app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
||||
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
|
||||
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
|
||||
|
||||
@@ -539,6 +550,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
||||
"tika_server_url": app.state.config.TIKA_SERVER_URL,
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": app.state.config.TEXT_SPLITTER,
|
||||
"chunk_size": app.state.config.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
|
||||
},
|
||||
@@ -599,11 +611,10 @@ class QuerySettingsForm(BaseModel):
|
||||
async def update_query_settings(
|
||||
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.RAG_TEMPLATE = (
|
||||
form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE
|
||||
)
|
||||
app.state.config.RAG_TEMPLATE = form_data.template
|
||||
app.state.config.TOP_K = form_data.k if form_data.k else 4
|
||||
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
||||
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
|
||||
form_data.hybrid if form_data.hybrid else False
|
||||
)
|
||||
@@ -648,18 +659,41 @@ def save_docs_to_vector_db(
|
||||
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
|
||||
|
||||
if split:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.config.CHUNK_SIZE,
|
||||
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
if app.state.config.TEXT_SPLITTER in ["", "character"]:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.config.CHUNK_SIZE,
|
||||
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
elif app.state.config.TEXT_SPLITTER == "token":
|
||||
text_splitter = TokenTextSplitter(
|
||||
encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
|
||||
chunk_size=app.state.config.CHUNK_SIZE,
|
||||
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
|
||||
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
if len(docs) == 0:
|
||||
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
|
||||
metadatas = [
|
||||
{
|
||||
**doc.metadata,
|
||||
**(metadata if metadata else {}),
|
||||
"embedding_config": json.dumps(
|
||||
{
|
||||
"engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
),
|
||||
}
|
||||
for doc in docs
|
||||
]
|
||||
|
||||
# ChromaDB does not like datetime formats
|
||||
# for meta-data so convert them to string.
|
||||
@@ -1255,6 +1289,7 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
|
||||
@app.post("/reset/db")
|
||||
def reset_vector_db(user=Depends(get_admin_user)):
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
Knowledges.delete_all_knowledge()
|
||||
|
||||
|
||||
@app.post("/reset/uploads")
|
||||
@@ -1277,28 +1312,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
||||
print(f"The directory {folder} does not exist")
|
||||
except Exception as e:
|
||||
print(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@app.post("/reset")
|
||||
def reset(user=Depends(get_admin_user)) -> bool:
|
||||
folder = f"{UPLOAD_DIR}"
|
||||
for filename in os.listdir(folder):
|
||||
file_path = os.path.join(folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
|
||||
|
||||
try:
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.config import DEFAULT_RAG_TEMPLATE
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -239,8 +240,13 @@ def query_collection_with_hybrid_search(
|
||||
|
||||
|
||||
def rag_template(template: str, context: str, query: str):
|
||||
count = template.count("[context]")
|
||||
assert "[context]" in template, "RAG template does not contain '[context]'"
|
||||
if template == "":
|
||||
template = DEFAULT_RAG_TEMPLATE
|
||||
|
||||
if "[context]" not in template and "{{CONTEXT}}" not in template:
|
||||
log.debug(
|
||||
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
|
||||
)
|
||||
|
||||
if "<context>" in context and "</context>" in context:
|
||||
log.debug(
|
||||
@@ -249,14 +255,25 @@ def rag_template(template: str, context: str, query: str):
|
||||
"nothing, or the user might be trying to hack something."
|
||||
)
|
||||
|
||||
query_placeholders = []
|
||||
if "[query]" in context:
|
||||
query_placeholder = f"[query-{str(uuid.uuid4())}]"
|
||||
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
|
||||
template = template.replace("[query]", query_placeholder)
|
||||
template = template.replace("[context]", context)
|
||||
query_placeholders.append(query_placeholder)
|
||||
|
||||
if "{{QUERY}}" in context:
|
||||
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
|
||||
template = template.replace("{{QUERY}}", query_placeholder)
|
||||
query_placeholders.append(query_placeholder)
|
||||
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("{{CONTEXT}}", context)
|
||||
template = template.replace("[query]", query)
|
||||
template = template.replace("{{QUERY}}", query)
|
||||
|
||||
for query_placeholder in query_placeholders:
|
||||
template = template.replace(query_placeholder, query)
|
||||
else:
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("[query]", query)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@@ -375,8 +392,21 @@ def get_rag_context(
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
file_names = list(
|
||||
set(
|
||||
[
|
||||
metadata["name"]
|
||||
for metadata in context["metadatas"][0]
|
||||
if metadata is not None and "name" in metadata
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
contexts.append(
|
||||
"\n\n".join(
|
||||
(", ".join(file_names) + ":\n\n")
|
||||
if file_names
|
||||
else ""
|
||||
+ "\n\n".join(
|
||||
[text for text in context["documents"][0] if text is not None]
|
||||
)
|
||||
)
|
||||
@@ -393,6 +423,7 @@ def get_rag_context(
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
print(contexts, citations)
|
||||
return contexts, citations
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user