Merge branch 'open-webui:main' into main

This commit is contained in:
Jarrod Lowe
2025-04-17 15:46:24 +12:00
committed by GitHub
118 changed files with 5187 additions and 3201 deletions

View File

@@ -76,11 +76,11 @@ def serve(
from open_webui.env import UVICORN_WORKERS # Import the workers setting
uvicorn.run(
open_webui.main.app,
host=host,
port=port,
open_webui.main.app,
host=host,
port=port,
forwarded_allow_ips="*",
workers=UVICORN_WORKERS
workers=UVICORN_WORKERS,
)

View File

@@ -201,7 +201,10 @@ def save_config(config):
T = TypeVar("T")
ENABLE_PERSISTENT_CONFIG = os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true"
ENABLE_PERSISTENT_CONFIG = (
os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true"
)
class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
@@ -612,10 +615,16 @@ def load_oauth_providers():
"scope": OAUTH_SCOPES.value,
}
if OAUTH_CODE_CHALLENGE_METHOD.value and OAUTH_CODE_CHALLENGE_METHOD.value == "S256":
if (
OAUTH_CODE_CHALLENGE_METHOD.value
and OAUTH_CODE_CHALLENGE_METHOD.value == "S256"
):
client_kwargs["code_challenge_method"] = "S256"
elif OAUTH_CODE_CHALLENGE_METHOD.value:
raise Exception('Code challenge methods other than "%s" not supported. Given: "%s"' % ("S256", OAUTH_CODE_CHALLENGE_METHOD.value))
raise Exception(
'Code challenge methods other than "%s" not supported. Given: "%s"'
% ("S256", OAUTH_CODE_CHALLENGE_METHOD.value)
)
client.register(
name="oidc",
@@ -1053,6 +1062,22 @@ USER_PERMISSIONS_CHAT_EDIT = (
os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_STT = (
os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_TTS = (
os.environ.get("USER_PERMISSIONS_CHAT_TTS", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_CALL = (
os.environ.get("USER_PERMISSIONS_CHAT_CALL", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = (
os.environ.get("USER_PERMISSIONS_CHAT_MULTIPLE_MODELS", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_TEMPORARY = (
os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
)
@@ -1062,6 +1087,7 @@ USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED = (
== "true"
)
USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS = (
os.environ.get("USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS", "False").lower()
== "true"
@@ -1100,6 +1126,10 @@ DEFAULT_USER_PERMISSIONS = {
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
"delete": USER_PERMISSIONS_CHAT_DELETE,
"edit": USER_PERMISSIONS_CHAT_EDIT,
"stt": USER_PERMISSIONS_CHAT_STT,
"tts": USER_PERMISSIONS_CHAT_TTS,
"call": USER_PERMISSIONS_CHAT_CALL,
"multiple_models": USER_PERMISSIONS_CHAT_MULTIPLE_MODELS,
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
"temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED,
},
@@ -1820,12 +1850,6 @@ RAG_FILE_MAX_SIZE = PersistentConfig(
),
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
"rag.enable_web_loader_ssl_verification",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
RAG_EMBEDDING_ENGINE = PersistentConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
@@ -1990,16 +2014,20 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
)
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
"ENABLE_RAG_WEB_SEARCH",
####################################
# Web Search (RAG)
####################################
ENABLE_WEB_SEARCH = PersistentConfig(
"ENABLE_WEB_SEARCH",
"rag.web.search.enable",
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
os.getenv("ENABLE_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
"RAG_WEB_SEARCH_ENGINE",
WEB_SEARCH_ENGINE = PersistentConfig(
"WEB_SEARCH_ENGINE",
"rag.web.search.engine",
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
os.getenv("WEB_SEARCH_ENGINE", ""),
)
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
@@ -2008,10 +2036,18 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
)
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
int(os.getenv("WEB_SEARCH_RESULT_COUNT", "3")),
)
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.web.search.domain.filter_list",
[
# "wikipedia.com",
@@ -2020,6 +2056,30 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
],
)
WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
"WEB_SEARCH_CONCURRENT_REQUESTS",
"rag.web.search.concurrent_requests",
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
WEB_LOADER_ENGINE = PersistentConfig(
"WEB_LOADER_ENGINE",
"rag.web.loader.engine",
os.environ.get("WEB_LOADER_ENGINE", ""),
)
ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
"ENABLE_WEB_LOADER_SSL_VERIFICATION",
"rag.web.loader.ssl_verification",
os.environ.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
WEB_SEARCH_TRUST_ENV = PersistentConfig(
"WEB_SEARCH_TRUST_ENV",
"rag.web.search.trust_env",
os.getenv("WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
@@ -2087,18 +2147,6 @@ SERPLY_API_KEY = PersistentConfig(
os.getenv("SERPLY_API_KEY", ""),
)
TAVILY_API_KEY = PersistentConfig(
"TAVILY_API_KEY",
"rag.web.search.tavily_api_key",
os.getenv("TAVILY_API_KEY", ""),
)
TAVILY_EXTRACT_DEPTH = PersistentConfig(
"TAVILY_EXTRACT_DEPTH",
"rag.web.search.tavily_extract_depth",
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
)
JINA_API_KEY = PersistentConfig(
"JINA_API_KEY",
"rag.web.search.jina_api_key",
@@ -2167,54 +2215,43 @@ SOUGOU_API_SK = PersistentConfig(
os.getenv("SOUGOU_API_SK", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
TAVILY_API_KEY = PersistentConfig(
"TAVILY_API_KEY",
"rag.web.search.tavily_api_key",
os.getenv("TAVILY_API_KEY", ""),
)
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
"RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
"rag.web.search.concurrent_requests",
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
TAVILY_EXTRACT_DEPTH = PersistentConfig(
"TAVILY_EXTRACT_DEPTH",
"rag.web.search.tavily_extract_depth",
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
)
RAG_WEB_LOADER_ENGINE = PersistentConfig(
"RAG_WEB_LOADER_ENGINE",
"rag.web.loader.engine",
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
)
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
"RAG_WEB_SEARCH_TRUST_ENV",
"rag.web.search.trust_env",
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
)
PLAYWRIGHT_WS_URI = PersistentConfig(
"PLAYWRIGHT_WS_URI",
"rag.web.loader.engine.playwright.ws.uri",
os.environ.get("PLAYWRIGHT_WS_URI", None),
PLAYWRIGHT_WS_URL = PersistentConfig(
"PLAYWRIGHT_WS_URL",
"rag.web.loader.playwright_ws_url",
os.environ.get("PLAYWRIGHT_WS_URL", ""),
)
PLAYWRIGHT_TIMEOUT = PersistentConfig(
"PLAYWRIGHT_TIMEOUT",
"rag.web.loader.engine.playwright.timeout",
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10")),
"rag.web.loader.playwright_timeout",
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10000")),
)
FIRECRAWL_API_KEY = PersistentConfig(
"FIRECRAWL_API_KEY",
"firecrawl.api_key",
"rag.web.loader.firecrawl_api_key",
os.environ.get("FIRECRAWL_API_KEY", ""),
)
FIRECRAWL_API_BASE_URL = PersistentConfig(
"FIRECRAWL_API_BASE_URL",
"firecrawl.api_url",
"rag.web.loader.firecrawl_api_url",
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
)
####################################
# Images
####################################
@@ -2467,6 +2504,13 @@ WHISPER_MODEL_AUTO_UPDATE = (
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
WHISPER_VAD_FILTER = PersistentConfig(
"WHISPER_VAD_FILTER",
"audio.stt.whisper_vad_filter",
os.getenv("WHISPER_VAD_FILTER", "False").lower() == "true",
)
# Add Deepgram configuration
DEEPGRAM_API_KEY = PersistentConfig(
"DEEPGRAM_API_KEY",
@@ -2474,6 +2518,7 @@ DEEPGRAM_API_KEY = PersistentConfig(
os.getenv("DEEPGRAM_API_KEY", ""),
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",

View File

@@ -498,4 +498,4 @@ PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
# PROGRESSIVE WEB APP OPTIONS
####################################
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")

View File

@@ -160,12 +160,13 @@ from open_webui.config import (
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
PLAYWRIGHT_WS_URI,
PLAYWRIGHT_WS_URL,
PLAYWRIGHT_TIMEOUT,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
RAG_WEB_LOADER_ENGINE,
WEB_LOADER_ENGINE,
WHISPER_MODEL,
WHISPER_VAD_FILTER,
DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
@@ -205,12 +206,13 @@ from open_webui.config import (
YOUTUBE_LOADER_LANGUAGE,
YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE,
ENABLE_WEB_SEARCH,
WEB_SEARCH_ENGINE,
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_WEB_SEARCH_TRUST_ENV,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
WEB_SEARCH_RESULT_COUNT,
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
WEB_SEARCH_DOMAIN_FILTER_LIST,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
@@ -240,8 +242,7 @@ from open_webui.config import (
ONEDRIVE_CLIENT_ID,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENABLE_WEB_LOADER_SSL_VERIFICATION,
ENABLE_GOOGLE_DRIVE_INTEGRATION,
ENABLE_ONEDRIVE_INTEGRATION,
UPLOAD_DIR,
@@ -373,7 +374,11 @@ from open_webui.utils.auth import (
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
from open_webui.tasks import (
list_task_ids_by_chat_id,
stop_task,
list_tasks,
) # Import from tasks.py
from open_webui.utils.redis import get_sentinels_from_env
@@ -596,9 +601,7 @@ app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
@@ -631,12 +634,16 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.ENABLE_WEB_SEARCH = ENABLE_WEB_SEARCH
app.state.config.WEB_SEARCH_ENGINE = WEB_SEARCH_ENGINE
app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.WEB_SEARCH_RESULT_COUNT = WEB_SEARCH_RESULT_COUNT
app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
@@ -664,11 +671,8 @@ app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
@@ -788,6 +792,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.config.WHISPER_VAD_FILTER = WHISPER_VAD_FILTER
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
@@ -1022,14 +1027,19 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
continue
model_tags = [
tag.get("name")
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
]
tags = [tag.get("name") for tag in model.get("tags", [])]
try:
model_tags = [
tag.get("name")
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
]
tags = [tag.get("name") for tag in model.get("tags", [])]
tags = list(set(model_tags + tags))
model["tags"] = [{"name": tag} for tag in tags]
tags = list(set(model_tags + tags))
model["tags"] = [{"name": tag} for tag in tags]
except Exception as e:
log.debug(f"Error processing model tags: {e}")
model["tags"] = []
pass
models.append(model)
@@ -1199,7 +1209,7 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
try:
result = await stop_task(task_id) # Use the function from tasks.py
result = await stop_task(task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -1207,7 +1217,19 @@ async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()} # Use the function from tasks.py
return {"tasks": list_tasks()}
@app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
    """Return the IDs of the tasks registered for one of the caller's chats.

    Args:
        chat_id: Identifier of the chat, taken from the URL path.
        user: The verified user resolved by the `get_verified_user` dependency.

    Returns:
        A dict of the form ``{"task_ids": [...]}``. The list is empty when the
        chat does not exist or is not owned by the requesting user; both cases
        produce the same response.
    """
    chat = Chats.get_chat_by_id(chat_id)
    if chat is None or chat.user_id != user.id:
        return {"task_ids": []}

    task_ids = list_task_ids_by_chat_id(chat_id)
    # Go through the module logger rather than print() so the message respects
    # the configured log level instead of always landing on stdout.
    log.debug("Task IDs for chat %s: %s", chat_id, task_ids)
    return {"task_ids": task_ids}
##################################
@@ -1263,7 +1285,7 @@ async def get_app_config(request: Request):
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
"enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,

View File

@@ -63,14 +63,15 @@ class MemoriesTable:
else:
return None
def update_memory_by_id(
def update_memory_by_id_and_user_id(
self,
id: str,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
try:
db.query(Memory).filter_by(id=id).update(
db.query(Memory).filter_by(id=id, user_id=user_id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()

View File

@@ -297,7 +297,9 @@ def query_collection_with_hybrid_search(
collection_results = {}
for collection_name in collection_names:
try:
log.debug(f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}")
log.debug(
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
)
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
@@ -619,7 +621,9 @@ def generate_openai_batch_embeddings(
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}")
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
@@ -662,7 +666,9 @@ def generate_ollama_batch_embeddings(
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}")
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

View File

@@ -28,9 +28,9 @@ from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URI,
PLAYWRIGHT_WS_URL,
PLAYWRIGHT_TIMEOUT,
RAG_WEB_LOADER_ENGINE,
WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
TAVILY_API_KEY,
@@ -584,13 +584,6 @@ class SafeWebBaseLoader(WebBaseLoader):
return [document async for document in self.alazy_load()]
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
@@ -608,27 +601,36 @@ def get_web_loader(
"trust_env": trust_env,
}
if RAG_WEB_LOADER_ENGINE.value == "playwright":
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
WebLoaderClass = SafeWebBaseLoader
if WEB_LOADER_ENGINE.value == "playwright":
WebLoaderClass = SafePlaywrightURLLoader
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
if PLAYWRIGHT_WS_URI.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
if PLAYWRIGHT_WS_URL.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
if WEB_LOADER_ENGINE.value == "firecrawl":
WebLoaderClass = SafeFireCrawlLoader
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
if RAG_WEB_LOADER_ENGINE.value == "tavily":
if WEB_LOADER_ENGINE.value == "tavily":
WebLoaderClass = SafeTavilyLoader
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
# Create the appropriate WebLoader based on the configuration
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
web_loader = WebLoaderClass(**web_loader_args)
if WebLoaderClass:
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
log.debug(
"Using WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader
return web_loader
else:
raise ValueError(
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
)

View File

@@ -330,7 +330,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -384,7 +384,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -440,7 +440,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -497,7 +497,11 @@ def transcribe(request: Request, file_path):
)
model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
segments, info = model.transcribe(
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
@@ -624,10 +628,7 @@ def transcribe(request: Request, file_path):
elif request.app.state.config.STT_ENGINE == "azure":
# Check file exists and size
if not os.path.exists(file_path):
raise HTTPException(
status_code=400,
detail="Audio file not found"
)
raise HTTPException(status_code=400, detail="Audio file not found")
# Check file size (Azure has a larger limit of 200MB)
file_size = os.path.getsize(file_path)
@@ -643,11 +644,22 @@ def transcribe(request: Request, file_path):
# IF NO LOCALES, USE DEFAULTS
if len(locales) < 2:
locales = ['en-US', 'es-ES', 'es-MX', 'fr-FR', 'hi-IN',
'it-IT','de-DE', 'en-GB', 'en-IN', 'ja-JP',
'ko-KR', 'pt-BR', 'zh-CN']
locales = ','.join(locales)
locales = [
"en-US",
"es-ES",
"es-MX",
"fr-FR",
"hi-IN",
"it-IT",
"de-DE",
"en-GB",
"en-IN",
"ja-JP",
"ko-KR",
"pt-BR",
"zh-CN",
]
locales = ",".join(locales)
if not api_key or not region:
raise HTTPException(
@@ -658,22 +670,26 @@ def transcribe(request: Request, file_path):
r = None
try:
# Prepare the request
data = {'definition': json.dumps({
"locales": locales.split(','),
"diarization": {"maxSpeakers": 3,"enabled": True}
} if locales else {}
)
data = {
"definition": json.dumps(
{
"locales": locales.split(","),
"diarization": {"maxSpeakers": 3, "enabled": True},
}
if locales
else {}
)
}
url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
# Use context manager to ensure file is properly closed
with open(file_path, 'rb') as audio_file:
with open(file_path, "rb") as audio_file:
r = requests.post(
url=url,
files={'audio': audio_file},
files={"audio": audio_file},
data=data,
headers={
'Ocp-Apim-Subscription-Key': api_key,
"Ocp-Apim-Subscription-Key": api_key,
},
)
@@ -681,11 +697,11 @@ def transcribe(request: Request, file_path):
response = r.json()
# Extract transcript from response
if not response.get('combinedPhrases'):
if not response.get("combinedPhrases"):
raise ValueError("No transcription found in response")
# Get the full transcript from combinedPhrases
transcript = response['combinedPhrases'][0].get('text', '').strip()
transcript = response["combinedPhrases"][0].get("text", "").strip()
if not transcript:
raise ValueError("Empty transcript in response")
@@ -718,7 +734,7 @@ def transcribe(request: Request, file_path):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, 'status_code', 500) if r else 500,
status_code=getattr(r, "status_code", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)

View File

@@ -231,13 +231,15 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
email = entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"]
email = entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"].value # retrieve the Attribute value
if not email:
raise HTTPException(400, "User does not have a valid email address.")
elif isinstance(email, str):
email = email.lower()
elif isinstance(email, list):
email = email[0].lower()
else:
email = str(email).lower()
cn = str(entry["cn"])
user_dn = entry.entry_dn

View File

@@ -579,7 +579,12 @@ async def clone_chat_by_id(
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_share_id(id)
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
else:
chat = Chats.get_chat_by_share_id(id)
if chat:
updated_chat = {
**chat.chat,

View File

@@ -159,7 +159,6 @@ async def create_new_knowledge(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_EXISTS,
)
############################
@@ -168,20 +167,17 @@ async def create_new_knowledge(
@router.post("/reindex", response_model=bool)
async def reindex_knowledge_files(
request: Request,
user=Depends(get_verified_user)
):
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
knowledge_bases = Knowledges.get_knowledge_bases()
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
for knowledge_base in knowledge_bases:
try:
files = Files.get_files_by_ids(knowledge_base.data.get("file_ids", []))
@@ -195,34 +191,40 @@ async def reindex_knowledge_files(
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error deleting vector DB collection"
detail=f"Error deleting vector DB collection",
)
failed_files = []
for file in files:
try:
process_file(
request,
ProcessFileForm(file_id=file.id, collection_name=knowledge_base.id),
ProcessFileForm(
file_id=file.id, collection_name=knowledge_base.id
),
user=user,
)
except Exception as e:
log.error(f"Error processing file {file.filename} (ID: {file.id}): {str(e)}")
log.error(
f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
)
failed_files.append({"file_id": file.id, "error": str(e)})
continue
except Exception as e:
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing knowledge base"
detail=f"Error processing knowledge base",
)
if failed_files:
log.warning(f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}")
log.warning(
f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
)
for failed in failed_files:
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
log.info("Reindexing completed successfully")
return True
@@ -742,6 +744,3 @@ def add_files_to_knowledge_batch(
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
)

View File

@@ -153,7 +153,9 @@ async def update_memory_by_id(
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
memory = Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content
)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")

File diff suppressed because it is too large Load Diff

View File

@@ -88,6 +88,10 @@ class ChatPermissions(BaseModel):
file_upload: bool = True
delete: bool = True
edit: bool = True
stt: bool = True
tts: bool = True
call: bool = True
multiple_models: bool = True
temporary: bool = True
temporary_enforced: bool = False

View File

@@ -9,9 +9,8 @@ from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
from open_webui.utils.redis import (
parse_redis_sentinel_url,
get_sentinels_from_env,
AsyncRedisSentinelManager,
get_sentinel_url_from_env,
)
from open_webui.env import (
@@ -38,15 +37,10 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
if WEBSOCKET_MANAGER == "redis":
if WEBSOCKET_SENTINEL_HOSTS:
redis_config = parse_redis_sentinel_url(WEBSOCKET_REDIS_URL)
mgr = AsyncRedisSentinelManager(
WEBSOCKET_SENTINEL_HOSTS.split(","),
sentinel_port=int(WEBSOCKET_SENTINEL_PORT),
redis_port=redis_config["port"],
service=redis_config["service"],
db=redis_config["db"],
username=redis_config["username"],
password=redis_config["password"],
mgr = socketio.AsyncRedisManager(
get_sentinel_url_from_env(
WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
)
)
else:
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
@@ -345,16 +339,17 @@ def get_event_emitter(request_info, update_db=True):
request_info["message_id"],
)
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
if message:
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")

View File

@@ -5,16 +5,23 @@ from uuid import uuid4
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
def cleanup_task(task_id: str):
def cleanup_task(task_id: str, id=None):
    """
    Remove a completed or canceled task from the global `tasks` dictionary
    and from the `chat_tasks` bucket it was registered under.

    Args:
        task_id: Identifier of the task to forget. Unknown IDs are ignored.
        id: Key of the `chat_tasks` bucket the task was filed under. This may
            be None when the task was created without an ID.
    """
    tasks.pop(task_id, None)  # Remove the task if it exists

    # Look the bucket up by key instead of testing `id` for truthiness.
    # create_task files every task under chat_tasks[id], even when id is None,
    # so a truthiness guard would leave those entries behind forever.
    bucket = chat_tasks.get(id)
    if bucket is not None and task_id in bucket:
        bucket.remove(task_id)
        if not bucket:  # If no tasks left for this ID, remove the entry
            chat_tasks.pop(id, None)
def create_task(coroutine):
def create_task(coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -22,9 +29,15 @@ def create_task(coroutine):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id))
task.add_done_callback(lambda t: cleanup_task(task_id, id))
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
if chat_tasks.get(id):
chat_tasks[id].append(task_id)
else:
chat_tasks[id] = [task_id]
return task_id, task
@@ -42,6 +55,13 @@ def list_tasks():
return list(tasks.keys())
def list_task_ids_by_chat_id(id):
    """
    List the IDs of all tasks associated with a specific ID.

    Args:
        id: Key to look up in the `chat_tasks` dictionary.

    Returns:
        A new list of task IDs, empty when nothing is registered under `id`.
        A copy is returned so that callers can neither mutate the registry
        through it nor see it change when a task is later removed.
    """
    return list(chat_tasks.get(id, ()))
async def stop_task(task_id: str):
"""
Cancel a running task and remove it from the global task list.

View File

@@ -235,46 +235,30 @@ async def chat_completion_tools_handler(
if isinstance(tool_result, str):
tool = tools[tool_function_name]
tool_id = tool.get("tool_id", "")
tool_name = (
f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
)
if tool.get("metadata", {}).get("citation", False) or tool.get(
"direct", False
):
# Citation is enabled for this tool
sources.append(
{
"source": {
"name": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
),
"name": (f"TOOL:{tool_name}"),
},
"document": [tool_result, *tool_result_files],
"metadata": [
{
"source": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
)
}
],
"document": [tool_result],
"metadata": [{"source": (f"TOOL:{tool_name}")}],
}
)
else:
sources.append(
{
"source": {},
"document": [tool_result, *tool_result_files],
"metadata": [
{
"source": (
f"TOOL:" + f"{tool_id}/{tool_function_name}"
if tool_id
else f"{tool_function_name}"
)
}
],
}
# Citation is not enabled for this tool
body["messages"] = add_or_update_user_message(
f"\nTool `{tool_name}` Output: {tool_result}",
body["messages"],
)
if (
@@ -550,13 +534,20 @@ async def chat_image_generation_handler(
}
)
for image in images:
await __event_emitter__(
{
"type": "message",
"data": {"content": f"![Generated Image]({image['url']})\n"},
}
)
await __event_emitter__(
{
"type": "files",
"data": {
"files": [
{
"type": "image",
"url": image["url"],
}
for image in images
]
},
}
)
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
except Exception as e:
@@ -2261,7 +2252,9 @@ async def process_chat_response(
await response.background()
# background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task(post_response_handler(response, events))
task_id, _ = create_task(
post_response_handler(response, events), id=metadata["chat_id"]
)
return {"status": True, "task_id": task_id}
else:

View File

@@ -114,9 +114,12 @@ async def get_all_models(request, user: UserModel = None):
for custom_model in custom_models:
if custom_model.base_model_id is None:
for model in models:
if (
custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0]
if custom_model.id == model["id"] or (
model.get("owned_by") == "ollama"
and custom_model.id
== model["id"].split(":")[
0
] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
):
if custom_model.is_active:
model["name"] = custom_model.name

View File

@@ -4,7 +4,7 @@ from redis import asyncio as aioredis
from urllib.parse import urlparse
def parse_redis_sentinel_url(redis_url):
def parse_redis_service_url(redis_url):
parsed_url = urlparse(redis_url)
if parsed_url.scheme != "redis":
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
@@ -20,7 +20,7 @@ def parse_redis_sentinel_url(redis_url):
def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
if redis_sentinels:
redis_config = parse_redis_sentinel_url(redis_url)
redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel(
redis_sentinels,
port=redis_config["port"],
@@ -45,65 +45,14 @@ def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
return []
class AsyncRedisSentinelManager(socketio.AsyncRedisManager):
    def __init__(
        self,
        sentinel_hosts,
        sentinel_port=26379,
        redis_port=6379,
        service="mymaster",
        db=0,
        username=None,
        password=None,
        channel="socketio",
        write_only=False,
        logger=None,
        redis_options=None,
    ):
        """
        Create a Socket.IO manager whose Redis connection comes from Sentinel.

        The constructor mostly mirrors the __init__ of AsyncRedisManager, but
        relies on the overridden _redis_connect() to obtain the connection
        through Redis Sentinel.

        :param sentinel_hosts: Iterable of Sentinel hosts
        :param sentinel_port: Port paired with every Sentinel host
        :param redis_port: Redis port (noted as currently unsupported by
                           aioredis)
        :param service: Name of the master service in Sentinel
        :param db: Redis database to use
        :param username: Redis username, if any (noted as currently
                         unsupported by aioredis; it is stored but not passed
                         on when connecting)
        :param password: Redis password, if any
        :param channel: The channel name on which the server sends and
                        receives notifications. Must be the same in all the
                        servers.
        :param write_only: If set to ``True``, only initialize to emit events.
                           The default of ``False`` initializes the class for
                           emitting and receiving.
        :param logger: Passed through to the grandparent constructor.
        :param redis_options: Additional keyword arguments passed to
                              ``aioredis.sentinel.Sentinel()``.
        """
        # Pair every Sentinel host with the shared Sentinel port.
        sentinel_addresses = []
        for sentinel_host in sentinel_hosts:
            sentinel_addresses.append((sentinel_host, sentinel_port))
        self._sentinels = sentinel_addresses

        # Settings read later by _redis_connect().
        self._service = service
        self._redis_port = redis_port
        self._db = db
        self._password = password
        self.redis_options = redis_options if redis_options else {}

        # Kept on the instance; _redis_connect() does not use these.
        self._username = username
        self._channel = channel

        # Connect first, then initialise the grandparent class directly:
        # starting the MRO lookup after AsyncRedisManager skips that class's
        # own __init__.
        self._redis_connect()
        grandparent = super(socketio.AsyncRedisManager, self)
        grandparent.__init__(channel=channel, write_only=write_only, logger=logger)

    def _redis_connect(self):
        """Set ``self.redis`` and ``self.pubsub`` from the Sentinel master."""
        sentinel_client = aioredis.sentinel.Sentinel(
            self._sentinels,
            port=self._redis_port,
            db=self._db,
            password=self._password,
            **self.redis_options,
        )
        master = sentinel_client.master_for(self._service)
        self.redis = master
        self.pubsub = master.pubsub(ignore_subscribe_messages=True)
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
    """
    Build a ``redis+sentinel://`` URL from a Redis URL and Sentinel settings.

    Credentials, database and service name are taken from the result of
    ``parse_redis_service_url(redis_url)``; every Sentinel host is paired with
    the same ``sentinel_port_env``.

    :param redis_url: Redis URL handed to ``parse_redis_service_url``.
    :param sentinel_hosts_env: Comma-separated Sentinel hosts. Whitespace
                               around a host and empty entries are ignored.
    :param sentinel_port_env: Port appended to every Sentinel host.
    :return: A string of the form
             ``redis+sentinel://[user:pass@]host:port[,host:port...]/db/service``.
    """
    redis_config = parse_redis_service_url(redis_url)
    username = redis_config["username"] or ""
    password = redis_config["password"] or ""
    # Only emit the "user:pass@" section when at least one credential is set.
    auth_part = f"{username}:{password}@" if username or password else ""
    # Values such as "host1, host2," would otherwise yield entries like
    # " host2:26379" or ":26379", so trim each host and drop empty ones.
    hosts = [host.strip() for host in sentinel_hosts_env.split(",")]
    hosts_part = ",".join(f"{host}:{sentinel_port_env}" for host in hosts if host)
    return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"

View File

@@ -152,6 +152,8 @@ def rag_template(template: str, context: str, query: str):
if template.strip() == "":
template = DEFAULT_RAG_TEMPLATE
template = prompt_template(template)
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."