Merge branch 'dev' of https://github.com/open-webui/open-webui into Dev-Individual-RAG-Config

This commit is contained in:
weberm1
2025-06-09 18:11:40 +02:00
88 changed files with 2048 additions and 647 deletions

View File

@@ -5,6 +5,7 @@ import os
import pkgutil
import sys
import shutil
from uuid import uuid4
from pathlib import Path
import markdown
@@ -130,6 +131,7 @@ else:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
# Function to parse each section
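A minimal, self-contained sketch of the pattern this hunk adds to env.py: an instance identifier that honors an INSTANCE_ID environment override and otherwise generates a fresh UUID per process.

# Sketch of the INSTANCE_ID pattern: env override, else a fresh UUID per process.
import os
from uuid import uuid4

INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
print(f"This instance is {INSTANCE_ID}")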

View File

@@ -8,6 +8,8 @@ import shutil
import sys
import time
import random
from uuid import uuid4
from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse
@@ -19,6 +21,7 @@ from aiocache import cached
import aiohttp
import anyio.to_thread
import requests
from redis import Redis
from fastapi import (
@@ -231,6 +234,9 @@ from open_webui.config import (
DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION,
DOCLING_PICTURE_DESCRIPTION_MODE,
DOCLING_PICTURE_DESCRIPTION_LOCAL,
DOCLING_PICTURE_DESCRIPTION_API,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
@@ -393,6 +399,7 @@ from open_webui.env import (
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
INSTANCE_ID,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
@@ -434,8 +441,10 @@ from open_webui.utils.auth import (
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import (
redis_task_command_listener,
list_task_ids_by_chat_id,
stop_task,
list_tasks,
@@ -487,7 +496,9 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.instance_id = INSTANCE_ID
start_logger()
if RESET_CONFIG_ON_START:
reset_config()
@@ -499,6 +510,19 @@ async def lifespan(app: FastAPI):
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
app.state.redis = get_redis_connection(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
async_mode=True,
)
if app.state.redis is not None:
app.state.redis_task_command_listener = asyncio.create_task(
redis_task_command_listener(app)
)
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE
@@ -507,6 +531,9 @@ async def lifespan(app: FastAPI):
yield
if hasattr(app.state, "redis_task_command_listener"):
app.state.redis_task_command_listener.cancel()
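A minimal sketch, with assumed names, of the lifespan pattern this hunk uses: start a long-running background task at startup and cancel it on shutdown. The real listener above consumes a Redis pub/sub channel; here asyncio.sleep stands in for pubsub.listen().

# Sketch of the lifespan start/cancel pattern (names assumed, not from this diff).
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI

async def listener(app: FastAPI):
    while True:
        await asyncio.sleep(1)  # stand-in for the Redis pubsub.listen() loop

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: launch the background listener task
    app.state.listener = asyncio.create_task(listener(app))
    yield
    # Shutdown: cancel it so the process can exit cleanly
    app.state.listener.cancel()

app = FastAPI(lifespan=lifespan)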
app = FastAPI(
title="Open WebUI",
@@ -518,10 +545,12 @@ app = FastAPI(
oauth_manager = OAuthManager(app)
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
)
app.state.redis = None
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
@@ -706,6 +735,9 @@ app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -1389,26 +1421,30 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user)
):
try:
result = await stop_task(task_id)
result = await stop_task(request, task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()}
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
return {"tasks": await list_tasks(request)}
@app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
async def list_tasks_by_chat_id_endpoint(
request: Request, chat_id: str, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
task_ids = list_task_ids_by_chat_id(chat_id)
task_ids = await list_task_ids_by_chat_id(request, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}

View File

@@ -245,8 +245,9 @@ class KnowledgeTable:
        from sqlalchemy import func

        with get_db() as db:
            if db.bind.dialect.name == "sqlite":
                query = db.query(Knowledge).filter(
                    func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model
                )
            elif db.bind.dialect.name == "postgresql":
                query = db.query(Knowledge).filter(
                    Knowledge.rag_config.op("->>")(model_type) == model,
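A sketch of the dialect-aware JSON filtering shown above, assuming a Knowledge model with a JSON rag_config column and an open SQLAlchemy session; SQLite and PostgreSQL spell JSON field access differently.

from sqlalchemy import func

def filter_by_model(db, model_type: str, model: str):
    if db.bind.dialect.name == "sqlite":
        # SQLite: json_extract with a JSON path expression
        return db.query(Knowledge).filter(
            func.json_extract(Knowledge.rag_config, f"$.{model_type}") == model
        )
    # PostgreSQL: the ->> operator returns the field as text
    return db.query(Knowledge).filter(
        Knowledge.rag_config.op("->>")(model_type) == model
    )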

View File

@@ -2,6 +2,7 @@ import requests
import logging
import ftfy
import sys
import json
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
@@ -154,6 +155,24 @@ class DoclingLoader:
"do_picture_description"
)
picture_description_mode = self.params.get(
"picture_description_mode", ""
).lower()
if picture_description_mode == "local" and self.params.get(
"picture_description_local", {}
):
params["picture_description_local"] = self.params.get(
"picture_description_local", {}
)
elif picture_description_mode == "api" and self.params.get(
"picture_description_api", {}
):
params["picture_description_api"] = self.params.get(
"picture_description_api", {}
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
@@ -281,17 +300,20 @@ class Loader:
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
params = {}
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_description": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
params=params,
)
elif (
self.engine == "document_intelligence"
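A small sketch of the DOCLING_PARAMS normalization above: the loader accepts either a dict or a JSON string and falls back to an empty dict when parsing fails.

import json
import logging

log = logging.getLogger(__name__)

def normalize_params(params):
    # Accept a dict as-is; parse a JSON string; fall back to {} on bad input.
    if not isinstance(params, dict):
        try:
            params = json.loads(params)
        except json.JSONDecodeError:
            log.error("Invalid DOCLING_PARAMS format, expected JSON object")
            params = {}
    return params

print(normalize_params('{"ocr_engine": "easyocr", "ocr_lang": "en,de"}'))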

View File

@@ -1,4 +1,5 @@
import logging
from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
@@ -93,7 +94,6 @@ class YoutubeLoader:
"http": self.proxy_url,
"https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
@@ -110,11 +110,37 @@ class YoutubeLoader:
        for lang in self.language:
            try:
                transcript = transcript_list.find_transcript([lang])
                if transcript.is_generated:
                    log.debug(f"Found generated transcript for language '{lang}'")
                    try:
                        transcript = transcript_list.find_manually_created_transcript(
                            [lang]
                        )
                        log.debug(f"Found manual transcript for language '{lang}'")
                    except NoTranscriptFound:
                        log.debug(
                            f"No manual transcript found for language '{lang}', using generated"
                        )
                try:
                    transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
                except ParseError:
                    log.debug(f"Empty or invalid transcript for language '{lang}'")
                    continue
                if not transcript_pieces:
                    log.debug(f"Empty transcript for language '{lang}'")
                    continue
                transcript_text = " ".join(
                    map(
                        lambda transcript_piece: (
                            transcript_piece.text.strip(" ")
                            if hasattr(transcript_piece, "text")
                            else ""
                        ),
                        transcript_pieces,
                    )
                )
@@ -131,6 +157,4 @@ class YoutubeLoader:
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
)
        raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
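A sketch of the manual-over-generated preference this hunk introduces, assuming youtube-transcript-api's TranscriptList interface: keep the first match, but swap in a manually created track when the match is auto-generated and a manual one exists.

from youtube_transcript_api import NoTranscriptFound  # assumed available

def prefer_manual(transcript_list, lang: str):
    transcript = transcript_list.find_transcript([lang])
    if transcript.is_generated:
        try:
            transcript = transcript_list.find_manually_created_transcript([lang])
        except NoTranscriptFound:
            pass  # no manual track; fall back to the generated one
    return transcript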

View File

@@ -1,12 +1,16 @@
from typing import Optional, List, Dict, Any
import logging
import json
from sqlalchemy import (
func,
literal,
cast,
column,
create_engine,
Column,
Integer,
MetaData,
LargeBinary,
select,
text,
Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.config import (
PGVECTOR_DB_URL,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def pgcrypto_encrypt(val, key):
    return func.pgp_sym_encrypt(val, literal(key))


def pgcrypto_decrypt(col, key, outtype="text"):
    return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
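A sketch of a round trip through these helpers, assuming a PostgreSQL session with the pgcrypto extension enabled; the key value is illustrative.

from sqlalchemy import select, Text

key = "my-secret-key"  # stand-in for PGVECTOR_PGCRYPTO_KEY
# Encrypt a literal, decrypt it again, and cast the result back to text.
stmt = select(
    pgcrypto_decrypt(pgcrypto_encrypt("hello", key), key, Text).label("roundtrip")
)
# session.execute(stmt).scalar_one() == "hello"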
class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    if PGVECTOR_PGCRYPTO:
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
@@ -147,44 +169,39 @@ class PgvectorClient(VectorDBBase):
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    # Use raw SQL for BYTEA/pgcrypto
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(
                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                                vector = EXCLUDED.vector,
                                collection_name = EXCLUDED.collection_name,
                                text = EXCLUDED.text,
                                vmetadata = EXCLUDED.vmetadata
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    existing = (
                        self.session.query(DocumentChunk)
                        .filter(DocumentChunk.id == item["id"])
                        .first()
                    )
                    if existing:
                        existing.vector = vector
                        existing.text = item["text"]
                        existing.vmetadata = item["metadata"]
                        existing.collection_name = (
                            collection_name  # Update collection_name if necessary
                        )
                    else:
                        new_chunk = DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=item["metadata"],
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(
                    f"Upserted {len(items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
@@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
.alias("query_vectors")
)
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                # Build where clause for vmetadata filter
                where_clauses = [DocumentChunk.collection_name == collection_name]
                for key, value in filter.items():
                    # decrypt then check key: JSON filter after decryption
                    where_clauses.append(
                        pgcrypto_decrypt(
                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                        )[key].astext
                        == str(value)
                    )
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(*where_clauses)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

            if not results:
                return None
@@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()
                if not results:
                    return None
                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
@@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None,
) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
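A sketch of the decrypt-then-filter pattern used in get() and delete() above, assuming the DocumentChunk model and pgcrypto key from this file: the encrypted BYTEA column is decrypted to JSONB, then indexed by key like a plain JSON column.

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import JSONB

def by_meta(collection_name: str, key: str, value: str):
    # Decrypt vmetadata to JSONB, then compare one field as text.
    meta = pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)
    return select(DocumentChunk.id).where(
        DocumentChunk.collection_name == collection_name,
        meta[key].astext == str(value),
    )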

View File

@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
try:
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
            # removeprefix drops the literal "data:" prefix; lstrip("data:") would
            # strip any of the characters d, a, t, : and corrupt types like "text/plain"
            mime_type = header.split(";")[0].removeprefix("data:")
img_data = base64.b64decode(encoded)
else:
mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
return img_data, mime_type
except Exception as e:
log.exception(f"Error loading image data: {e}")
        return None, None
def load_url_image_data(url, headers=None):
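A hypothetical round trip for load_b64_image_data above: build a data URL, then recover the raw bytes and MIME type.

import base64

png_bytes = b"\x89PNG\r\n\x1a\n"  # not a full image, just illustrative bytes
b64 = "data:image/png;base64," + base64.b64encode(png_bytes).decode()
img_data, mime_type = load_b64_image_data(b64)
assert mime_type == "image/png" and img_data == png_bytes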

View File

@@ -600,6 +600,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE),
"DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG),
"DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION),
"DOCLING_PICTURE_DESCRIPTION_MODE": rag_config.get("DOCLING_PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE),
"DOCLING_PICTURE_DESCRIPTION_LOCAL": rag_config.get("DOCLING_PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL),
"DOCLING_PICTURE_DESCRIPTION_API": rag_config.get("DOCLING_PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API),
"DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
"DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
"MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY),
@@ -766,6 +769,9 @@ class ConfigForm(BaseModel):
DOCLING_OCR_ENGINE: Optional[str] = None
DOCLING_OCR_LANG: Optional[str] = None
DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -1050,6 +1056,22 @@ async def update_rag_config(
else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
form_data.DOCLING_PICTURE_DESCRIPTION_MODE
if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
form_data.DOCLING_PICTURE_DESCRIPTION_API
if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
)
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -1307,6 +1329,9 @@ async def update_rag_config(
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
"DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
"DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -1667,6 +1692,15 @@ def process_file(
docling_do_picture_description=rag_config.get(
"DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
picture_description_mode = rag_config.get(
"PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
)
    picture_description_local = rag_config.get(
        "PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
    )
picture_description_api = rag_config.get(
"PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
)
pdf_extract_images = rag_config.get(
"PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES
)
@@ -1757,9 +1791,14 @@ def process_file(
EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
TIKA_SERVER_URL=tika_server_url,
DOCLING_SERVER_URL=docling_server_url,
            DOCLING_PARAMS={
                "ocr_engine": docling_ocr_engine,
                "ocr_lang": docling_ocr_lang,
                "do_picture_description": docling_do_picture_description,
                "picture_description_mode": picture_description_mode,
                "picture_description_local": picture_description_local,
                "picture_description_api": picture_description_api,
            },
PDF_EXTRACT_IMAGES=pdf_extract_images,
DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,

View File

@@ -33,7 +33,7 @@ class CodeForm(BaseModel):
@router.post("/code/format")
async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
try:
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}

View File

@@ -2,16 +2,87 @@
import asyncio
import json
from uuid import uuid4
from typing import Dict, List, Optional

from redis.asyncio import Redis
from fastapi import Request

# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
REDIS_TASKS_KEY = "open-webui:tasks"
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
def is_redis(request: Request) -> bool:
    # Helper used wherever a request is available to check whether Redis is configured
    return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
async def redis_task_command_listener(app):
redis: Redis = app.state.redis
pubsub = redis.pubsub()
await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
async for message in pubsub.listen():
if message["type"] != "message":
continue
try:
command = json.loads(message["data"])
if command.get("action") == "stop":
task_id = command.get("task_id")
local_task = tasks.get(task_id)
if local_task:
local_task.cancel()
except Exception as e:
print(f"Error handling distributed task command: {e}")
### ------------------------------
### REDIS-ENABLED HANDLERS
### ------------------------------
async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline()
pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
if chat_id:
pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
await pipe.execute()
async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id)
if chat_id:
pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
await pipe.execute()
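Note on redis_cleanup_task above: the inline await pipe.scard(...).execute() flushes the queued HDEL/SREM together with the SCARD, whose result is the last element of the returned list; the trailing await pipe.execute() then flushes only the DELETE queued afterwards. redis-py returns an empty list for a pipeline with no queued commands, so the final call is harmless when the set was not empty.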
async def redis_list_tasks(redis: Redis) -> List[str]:
return list(await redis.hkeys(REDIS_TASKS_KEY))
async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
async def redis_send_command(redis: Redis, command: dict):
await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
async def cleanup_task(request, task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
if is_redis(request):
await redis_cleanup_task(request.app.state.redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary
@@ -21,7 +92,7 @@ def cleanup_task(task_id: str, id=None):
chat_tasks.pop(id, None)
async def create_task(request, coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -29,7 +100,9 @@ def create_task(coroutine, id=None):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
    task.add_done_callback(
        lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
    )
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
@@ -38,34 +111,46 @@ def create_task(coroutine, id=None):
else:
chat_tasks[id] = [task_id]
if is_redis(request):
await redis_save_task(request.app.state.redis, task_id, id)
return task_id, task
def get_task(task_id: str):
"""
Retrieve a task by its task ID.
"""
return tasks.get(task_id)
async def list_tasks(request):
"""
List all currently active task IDs.
"""
if is_redis(request):
return await redis_list_tasks(request.app.state.redis)
return list(tasks.keys())
async def list_task_ids_by_chat_id(request, id):
"""
List all tasks associated with a specific ID.
"""
if is_redis(request):
return await redis_list_chat_tasks(request.app.state.redis, id)
return chat_tasks.get(id, [])
async def stop_task(request, task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
if is_redis(request):
# PUBSUB: All instances check if they have this task, and stop if so.
await redis_send_command(
request.app.state.redis,
{
"action": "stop",
"task_id": task_id,
},
)
# Optionally check if task_id still in Redis a few moments later for feedback?
return {"status": True, "message": f"Stop signal sent for {task_id}"}
task = tasks.get(task_id)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")

View File

@@ -23,6 +23,7 @@ from open_webui.env import (
TRUSTED_SIGNATURE_KEY,
STATIC_DIR,
SRC_LOG_LEVELS,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
)
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
@@ -157,6 +158,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
def get_current_user(
request: Request,
response: Response,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
@@ -225,6 +227,19 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER)
if trusted_email and user.email != trusted_email:
# Delete the token cookie
response.delete_cookie("token")
# Delete OAuth token if present
if request.cookies.get("oauth_id_token"):
response.delete_cookie("oauth_id_token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User mismatch. Please sign in again.",
)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
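A condensed sketch of the trusted-header consistency check added above; the request, response, and user objects are assumed, and the header name comes from WEBUI_AUTH_TRUSTED_EMAIL_HEADER.

from fastapi import HTTPException, status

def enforce_trusted_email(request, response, user, header_name: str):
    trusted_email = request.headers.get(header_name)
    if trusted_email and user.email != trusted_email:
        # Proxy-supplied identity no longer matches the session: force re-login.
        response.delete_cookie("token")
        if request.cookies.get("oauth_id_token"):
            response.delete_cookie("oauth_id_token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User mismatch. Please sign in again.",
        )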

View File

@@ -37,7 +37,12 @@ from open_webui.routers.tasks import (
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.images import (
load_b64_image_data,
image_generations,
GenerateImageForm,
upload_image,
)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
@@ -2278,28 +2283,21 @@ async def process_chat_response(
                                stdoutLines = stdout.split("\n")
                                for idx, line in enumerate(stdoutLines):
                                    if "data:image/png;base64" in line:
                                        image_url = ""
                                        # Extract base64 image data from the line
                                        image_data, content_type = (
                                            load_b64_image_data(line)
                                        )
                                        if image_data is not None:
                                            image_url = upload_image(
                                                request,
                                                image_data,
                                                content_type,
                                                metadata,
                                                user,
                                            )
                                        stdoutLines[idx] = (
                                            f"![Output Image]({image_url})"
                                        )
                                output["stdout"] = "\n".join(stdoutLines)
@@ -2310,30 +2308,22 @@ async def process_chat_response(
                                resultLines = result.split("\n")
                                for idx, line in enumerate(resultLines):
                                    if "data:image/png;base64" in line:
                                        image_url = ""
                                        # Extract base64 image data from the line
                                        image_data, content_type = (
                                            load_b64_image_data(line)
                                        )
                                        if image_data is not None:
                                            image_url = upload_image(
                                                request,
                                                image_data,
                                                content_type,
                                                metadata,
                                                user,
                                            )
                                        resultLines[idx] = (
                                            f"![Output Image]({image_url})"
                                        )
                                output["result"] = "\n".join(resultLines)
                    except Exception as e:
                        output = str(e)
@@ -2442,8 +2432,8 @@ async def process_chat_response(
await response.background()
# background_tasks.add_task(post_response_handler, response, events)
        task_id, _ = await create_task(
            request, post_response_handler(response, events), id=metadata["chat_id"]
        )
return {"status": True, "task_id": task_id}

View File

@@ -1,7 +1,6 @@
import socketio
import redis
from urllib.parse import urlparse
from typing import Optional
def parse_redis_service_url(redis_url):
@@ -18,23 +17,46 @@ def parse_redis_service_url(redis_url):
}
def get_redis_connection(
    redis_url, redis_sentinels, async_mode=False, decode_responses=True
):
    if async_mode:
        import redis.asyncio as redis

        # If using sentinel in async mode
        if redis_sentinels:
            redis_config = parse_redis_service_url(redis_url)
            sentinel = redis.sentinel.Sentinel(
                redis_sentinels,
                port=redis_config["port"],
                db=redis_config["db"],
                username=redis_config["username"],
                password=redis_config["password"],
                decode_responses=decode_responses,
            )
            # Get a master connection from Sentinel
            return sentinel.master_for(redis_config["service"])
        elif redis_url:
            return redis.from_url(redis_url, decode_responses=decode_responses)
        else:
            return None
    else:
        import redis

        # Standard Redis connection
        if redis_sentinels:
            redis_config = parse_redis_service_url(redis_url)
            sentinel = redis.sentinel.Sentinel(
                redis_sentinels,
                port=redis_config["port"],
                db=redis_config["db"],
                username=redis_config["username"],
                password=redis_config["password"],
                decode_responses=decode_responses,
            )
            return sentinel.master_for(redis_config["service"])
        elif redis_url:
            return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
        else:
            return None
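Hypothetical usage of the updated helper; the URL is a placeholder and sentinels are omitted. Both calls are None-safe when Redis is unconfigured.

# Async client (e.g. for the task registry) vs. the standard sync client.
redis_async = get_redis_connection("redis://localhost:6379/0", None, async_mode=True)
redis_sync = get_redis_connection("redis://localhost:6379/0", None)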
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):

View File

@@ -14,7 +14,11 @@ if [[ "${WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
python -c "import nltk; nltk.download('punkt_tab')"
fi
if [ -n "${WEBUI_SECRET_KEY_FILE}" ]; then
    KEY_FILE="${WEBUI_SECRET_KEY_FILE}"
else
    KEY_FILE=".webui_secret_key"
fi
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"

View File

@@ -18,6 +18,10 @@ IF /I "%WEB_LOADER_ENGINE%" == "playwright" (
)
SET "KEY_FILE=.webui_secret_key"
IF NOT "%WEBUI_SECRET_KEY_FILE%" == "" (
SET "KEY_FILE=%WEBUI_SECRET_KEY_FILE%"
)
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"