Merge branch 'upstream-dev' into dev

This commit is contained in:
Jannik Streidl
2024-10-14 09:50:40 +02:00
39 changed files with 1235 additions and 469 deletions

View File

@@ -18,7 +18,10 @@ from open_webui.config import (
OPENAI_API_KEYS,
AppConfig,
)
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@@ -179,7 +182,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=3)
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:

View File

@@ -15,6 +15,9 @@ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, sta
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.apps.webui.models.knowledge import Knowledges
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
# Document loaders
@@ -47,6 +50,8 @@ from open_webui.apps.retrieval.utils import (
from open_webui.apps.webui.models.files import Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
TIKTOKEN_ENCODING_NAME,
RAG_TEXT_SPLITTER,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
@@ -102,7 +107,7 @@ from open_webui.utils.misc import (
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.document_loaders import (
YoutubeLoader,
)
@@ -129,6 +134,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -171,9 +179,9 @@ def update_embedding_model(
auto_update: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
from sentence_transformers import SentenceTransformer
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
app.state.sentence_transformer_ef = SentenceTransformer(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@@ -384,18 +392,19 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"file": {
"max_size": app.state.config.FILE_MAX_SIZE,
"max_count": app.state.config.FILE_MAX_COUNT,
},
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
@@ -434,6 +443,7 @@ class ContentExtractionConfig(BaseModel):
class ChunkParamUpdateForm(BaseModel):
text_splitter: Optional[str] = None
chunk_size: int
chunk_overlap: int
@@ -493,6 +503,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
@@ -539,6 +550,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"text_splitter": app.state.config.TEXT_SPLITTER,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
@@ -599,11 +611,10 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template != "" else DEFAULT_RAG_TEMPLATE
)
app.state.config.RAG_TEMPLATE = form_data.template
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False
)
@@ -648,18 +659,41 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
if app.state.config.TEXT_SPLITTER in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
elif app.state.config.TEXT_SPLITTER == "token":
text_splitter = TokenTextSplitter(
encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
docs = text_splitter.split_documents(docs)
if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
texts = [doc.page_content for doc in docs]
metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
metadatas = [
{
**doc.metadata,
**(metadata if metadata else {}),
"embedding_config": json.dumps(
{
"engine": app.state.config.RAG_EMBEDDING_ENGINE,
"model": app.state.config.RAG_EMBEDDING_MODEL,
}
),
}
for doc in docs
]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
@@ -1255,6 +1289,7 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
@app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
VECTOR_DB_CLIENT.reset()
Knowledges.delete_all_knowledge()
@app.post("/reset/uploads")
@@ -1277,28 +1312,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
print(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
return True
@app.post("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try:
VECTOR_DB_CLIENT.reset()
except Exception as e:
log.exception(e)
return True

View File

@@ -19,6 +19,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
@@ -239,8 +240,13 @@ def query_collection_with_hybrid_search(
def rag_template(template: str, context: str, query: str):
count = template.count("[context]")
assert "[context]" in template, "RAG template does not contain '[context]'"
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
@@ -249,14 +255,25 @@ def rag_template(template: str, context: str, query: str):
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = f"[query-{str(uuid.uuid4())}]"
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
template = template.replace("[context]", context)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
else:
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
@@ -375,8 +392,21 @@ def get_rag_context(
for context in relevant_contexts:
try:
if "documents" in context:
file_names = list(
set(
[
metadata["name"]
for metadata in context["metadatas"][0]
if metadata is not None and "name" in metadata
]
)
)
contexts.append(
"\n\n".join(
(", ".join(file_names) + ":\n\n")
if file_names
else ""
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
@@ -393,6 +423,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
print(contexts, citations)
return contexts, citations

View File

@@ -61,6 +61,9 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
class ChatTitleForm(BaseModel):
title: str

View File

@@ -154,5 +154,15 @@ class KnowledgeTable:
except Exception:
return False
def delete_all_knowledge(self) -> bool:
with get_db() as db:
try:
db.query(Knowledge).delete()
db.commit()
return True
except Exception:
return False
Knowledges = KnowledgeTable()

View File

@@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tag(Base):
__tablename__ = "tag"
id = Column(String, primary_key=True)
id = Column(String)
name = Column(String)
user_id = Column(String)
meta = Column(JSON, nullable=True)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
class TagModel(BaseModel):
id: str
@@ -57,7 +60,8 @@ class TagTable:
return TagModel.model_validate(result)
else:
return None
except Exception:
except Exception as e:
print(e)
return None
def get_tag_by_name_and_user_id(
@@ -78,11 +82,15 @@ class TagTable:
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str
) -> list[TagModel]:
with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:

View File

@@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name(
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name(
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

View File

@@ -1,16 +1,14 @@
import site
from pathlib import Path
import black
import markdown
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.env import FONTS_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fpdf import FPDF
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.utils import get_admin_user
router = APIRouter()
@@ -56,58 +54,19 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatForm,
form_data: ChatTitleMessagesForm,
):
global FONTS_DIR
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
pdf = FPDF()
pdf.add_page()
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf")
pdf.set_font("NotoSans", size=12)
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"])
pdf.set_auto_page_break(auto=True, margin=15)
# Adjust the effective page width for multi_cell
effective_page_width = (
pdf.w - 2 * pdf.l_margin - 10
) # Subtracted an additional 10 for extra padding
# Add chat messages
for message in form_data.messages:
role = message["role"]
content = message["content"]
pdf.set_font("NotoSans", "B", size=14) # Bold for the role
pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L")
pdf.ln(1) # Extra space between messages
pdf.set_font("NotoSans", size=10) # Regular for content
pdf.multi_cell(effective_page_width, 6, content, 0, "L")
pdf.ln(1.5) # Extra space between messages
# Save the pdf with name .pdf
pdf_bytes = pdf.output()
return Response(
content=bytes(pdf_bytes),
media_type="application/pdf",
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
return Response(
content=pdf_bytes,
media_type="application/pdf",
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail=str(e))
@router.get("/db/download")