mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into main
This commit is contained in:
@@ -76,11 +76,11 @@ def serve(
|
||||
from open_webui.env import UVICORN_WORKERS # Import the workers setting
|
||||
|
||||
uvicorn.run(
|
||||
open_webui.main.app,
|
||||
host=host,
|
||||
port=port,
|
||||
open_webui.main.app,
|
||||
host=host,
|
||||
port=port,
|
||||
forwarded_allow_ips="*",
|
||||
workers=UVICORN_WORKERS
|
||||
workers=UVICORN_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -201,7 +201,10 @@ def save_config(config):
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
ENABLE_PERSISTENT_CONFIG = os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true"
|
||||
ENABLE_PERSISTENT_CONFIG = (
|
||||
os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
class PersistentConfig(Generic[T]):
|
||||
def __init__(self, env_name: str, config_path: str, env_value: T):
|
||||
@@ -612,10 +615,16 @@ def load_oauth_providers():
|
||||
"scope": OAUTH_SCOPES.value,
|
||||
}
|
||||
|
||||
if OAUTH_CODE_CHALLENGE_METHOD.value and OAUTH_CODE_CHALLENGE_METHOD.value == "S256":
|
||||
if (
|
||||
OAUTH_CODE_CHALLENGE_METHOD.value
|
||||
and OAUTH_CODE_CHALLENGE_METHOD.value == "S256"
|
||||
):
|
||||
client_kwargs["code_challenge_method"] = "S256"
|
||||
elif OAUTH_CODE_CHALLENGE_METHOD.value:
|
||||
raise Exception('Code challenge methods other than "%s" not supported. Given: "%s"' % ("S256", OAUTH_CODE_CHALLENGE_METHOD.value))
|
||||
raise Exception(
|
||||
'Code challenge methods other than "%s" not supported. Given: "%s"'
|
||||
% ("S256", OAUTH_CODE_CHALLENGE_METHOD.value)
|
||||
)
|
||||
|
||||
client.register(
|
||||
name="oidc",
|
||||
@@ -1053,6 +1062,22 @@ USER_PERMISSIONS_CHAT_EDIT = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_STT = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_TTS = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_TTS", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_CALL = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_CALL", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_MULTIPLE_MODELS", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_CHAT_TEMPORARY = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
|
||||
)
|
||||
@@ -1062,6 +1087,7 @@ USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED = (
|
||||
== "true"
|
||||
)
|
||||
|
||||
|
||||
USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS = (
|
||||
os.environ.get("USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS", "False").lower()
|
||||
== "true"
|
||||
@@ -1100,6 +1126,10 @@ DEFAULT_USER_PERMISSIONS = {
|
||||
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
|
||||
"delete": USER_PERMISSIONS_CHAT_DELETE,
|
||||
"edit": USER_PERMISSIONS_CHAT_EDIT,
|
||||
"stt": USER_PERMISSIONS_CHAT_STT,
|
||||
"tts": USER_PERMISSIONS_CHAT_TTS,
|
||||
"call": USER_PERMISSIONS_CHAT_CALL,
|
||||
"multiple_models": USER_PERMISSIONS_CHAT_MULTIPLE_MODELS,
|
||||
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
|
||||
"temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED,
|
||||
},
|
||||
@@ -1820,12 +1850,6 @@ RAG_FILE_MAX_SIZE = PersistentConfig(
|
||||
),
|
||||
)
|
||||
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
|
||||
"rag.enable_web_loader_ssl_verification",
|
||||
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_ENGINE = PersistentConfig(
|
||||
"RAG_EMBEDDING_ENGINE",
|
||||
"rag.embedding_engine",
|
||||
@@ -1990,16 +2014,20 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
|
||||
"ENABLE_RAG_WEB_SEARCH",
|
||||
####################################
|
||||
# Web Search (RAG)
|
||||
####################################
|
||||
|
||||
ENABLE_WEB_SEARCH = PersistentConfig(
|
||||
"ENABLE_WEB_SEARCH",
|
||||
"rag.web.search.enable",
|
||||
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
|
||||
os.getenv("ENABLE_WEB_SEARCH", "False").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_ENGINE",
|
||||
WEB_SEARCH_ENGINE = PersistentConfig(
|
||||
"WEB_SEARCH_ENGINE",
|
||||
"rag.web.search.engine",
|
||||
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
|
||||
os.getenv("WEB_SEARCH_ENGINE", ""),
|
||||
)
|
||||
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
@@ -2008,10 +2036,18 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||
"WEB_SEARCH_RESULT_COUNT",
|
||||
"rag.web.search.result_count",
|
||||
int(os.getenv("WEB_SEARCH_RESULT_COUNT", "3")),
|
||||
)
|
||||
|
||||
|
||||
# You can provide a list of your own websites to filter after performing a web search.
|
||||
# This ensures the highest level of safety and reliability of the information sources.
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
|
||||
WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST",
|
||||
"rag.web.search.domain.filter_list",
|
||||
[
|
||||
# "wikipedia.com",
|
||||
@@ -2020,6 +2056,30 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
],
|
||||
)
|
||||
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS",
|
||||
"rag.web.search.concurrent_requests",
|
||||
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
os.environ.get("WEB_LOADER_ENGINE", ""),
|
||||
)
|
||||
|
||||
ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
|
||||
"ENABLE_WEB_LOADER_SSL_VERIFICATION",
|
||||
"rag.web.loader.ssl_verification",
|
||||
os.environ.get("ENABLE_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
|
||||
)
|
||||
|
||||
WEB_SEARCH_TRUST_ENV = PersistentConfig(
|
||||
"WEB_SEARCH_TRUST_ENV",
|
||||
"rag.web.search.trust_env",
|
||||
os.getenv("WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
SEARXNG_QUERY_URL = PersistentConfig(
|
||||
"SEARXNG_QUERY_URL",
|
||||
@@ -2087,18 +2147,6 @@ SERPLY_API_KEY = PersistentConfig(
|
||||
os.getenv("SERPLY_API_KEY", ""),
|
||||
)
|
||||
|
||||
TAVILY_API_KEY = PersistentConfig(
|
||||
"TAVILY_API_KEY",
|
||||
"rag.web.search.tavily_api_key",
|
||||
os.getenv("TAVILY_API_KEY", ""),
|
||||
)
|
||||
|
||||
TAVILY_EXTRACT_DEPTH = PersistentConfig(
|
||||
"TAVILY_EXTRACT_DEPTH",
|
||||
"rag.web.search.tavily_extract_depth",
|
||||
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
|
||||
)
|
||||
|
||||
JINA_API_KEY = PersistentConfig(
|
||||
"JINA_API_KEY",
|
||||
"rag.web.search.jina_api_key",
|
||||
@@ -2167,54 +2215,43 @@ SOUGOU_API_SK = PersistentConfig(
|
||||
os.getenv("SOUGOU_API_SK", ""),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_RESULT_COUNT",
|
||||
"rag.web.search.result_count",
|
||||
int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
|
||||
TAVILY_API_KEY = PersistentConfig(
|
||||
"TAVILY_API_KEY",
|
||||
"rag.web.search.tavily_api_key",
|
||||
os.getenv("TAVILY_API_KEY", ""),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
|
||||
"rag.web.search.concurrent_requests",
|
||||
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
TAVILY_EXTRACT_DEPTH = PersistentConfig(
|
||||
"TAVILY_EXTRACT_DEPTH",
|
||||
"rag.web.search.tavily_extract_depth",
|
||||
os.getenv("TAVILY_EXTRACT_DEPTH", "basic"),
|
||||
)
|
||||
|
||||
RAG_WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"RAG_WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_TRUST_ENV",
|
||||
"rag.web.search.trust_env",
|
||||
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
|
||||
)
|
||||
|
||||
PLAYWRIGHT_WS_URI = PersistentConfig(
|
||||
"PLAYWRIGHT_WS_URI",
|
||||
"rag.web.loader.engine.playwright.ws.uri",
|
||||
os.environ.get("PLAYWRIGHT_WS_URI", None),
|
||||
PLAYWRIGHT_WS_URL = PersistentConfig(
|
||||
"PLAYWRIGHT_WS_URL",
|
||||
"rag.web.loader.playwright_ws_url",
|
||||
os.environ.get("PLAYWRIGHT_WS_URL", ""),
|
||||
)
|
||||
|
||||
PLAYWRIGHT_TIMEOUT = PersistentConfig(
|
||||
"PLAYWRIGHT_TIMEOUT",
|
||||
"rag.web.loader.engine.playwright.timeout",
|
||||
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10")),
|
||||
"rag.web.loader.playwright_timeout",
|
||||
int(os.environ.get("PLAYWRIGHT_TIMEOUT", "10000")),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_KEY = PersistentConfig(
|
||||
"FIRECRAWL_API_KEY",
|
||||
"firecrawl.api_key",
|
||||
"rag.web.loader.firecrawl_api_key",
|
||||
os.environ.get("FIRECRAWL_API_KEY", ""),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_BASE_URL = PersistentConfig(
|
||||
"FIRECRAWL_API_BASE_URL",
|
||||
"firecrawl.api_url",
|
||||
"rag.web.loader.firecrawl_api_url",
|
||||
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Images
|
||||
####################################
|
||||
@@ -2467,6 +2504,13 @@ WHISPER_MODEL_AUTO_UPDATE = (
|
||||
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
WHISPER_VAD_FILTER = PersistentConfig(
|
||||
"WHISPER_VAD_FILTER",
|
||||
"audio.stt.whisper_vad_filter",
|
||||
os.getenv("WHISPER_VAD_FILTER", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
# Add Deepgram configuration
|
||||
DEEPGRAM_API_KEY = PersistentConfig(
|
||||
"DEEPGRAM_API_KEY",
|
||||
@@ -2474,6 +2518,7 @@ DEEPGRAM_API_KEY = PersistentConfig(
|
||||
os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
)
|
||||
|
||||
|
||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||
"audio.stt.openai.api_base_url",
|
||||
|
||||
@@ -498,4 +498,4 @@ PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
|
||||
# PROGRESSIVE WEB APP OPTIONS
|
||||
####################################
|
||||
|
||||
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
|
||||
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
|
||||
|
||||
@@ -160,12 +160,13 @@ from open_webui.config import (
|
||||
AUDIO_TTS_VOICE,
|
||||
AUDIO_TTS_AZURE_SPEECH_REGION,
|
||||
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
PLAYWRIGHT_WS_URL,
|
||||
PLAYWRIGHT_TIMEOUT,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
WEB_LOADER_ENGINE,
|
||||
WHISPER_MODEL,
|
||||
WHISPER_VAD_FILTER,
|
||||
DEEPGRAM_API_KEY,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
@@ -205,12 +206,13 @@ from open_webui.config import (
|
||||
YOUTUBE_LOADER_LANGUAGE,
|
||||
YOUTUBE_LOADER_PROXY_URL,
|
||||
# Retrieval (Web Search)
|
||||
RAG_WEB_SEARCH_ENGINE,
|
||||
ENABLE_WEB_SEARCH,
|
||||
WEB_SEARCH_ENGINE,
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
RAG_WEB_SEARCH_TRUST_ENV,
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
WEB_SEARCH_RESULT_COUNT,
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
WEB_SEARCH_TRUST_ENV,
|
||||
WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
JINA_API_KEY,
|
||||
SEARCHAPI_API_KEY,
|
||||
SEARCHAPI_ENGINE,
|
||||
@@ -240,8 +242,7 @@ from open_webui.config import (
|
||||
ONEDRIVE_CLIENT_ID,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
ENABLE_ONEDRIVE_INTEGRATION,
|
||||
UPLOAD_DIR,
|
||||
@@ -373,7 +374,11 @@ from open_webui.utils.auth import (
|
||||
from open_webui.utils.oauth import OAuthManager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
from open_webui.tasks import (
|
||||
list_task_ids_by_chat_id,
|
||||
stop_task,
|
||||
list_tasks,
|
||||
) # Import from tasks.py
|
||||
|
||||
from open_webui.utils.redis import get_sentinels_from_env
|
||||
|
||||
@@ -596,9 +601,7 @@ app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
||||
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
@@ -631,12 +634,16 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
||||
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
|
||||
|
||||
|
||||
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.ENABLE_WEB_SEARCH = ENABLE_WEB_SEARCH
|
||||
app.state.config.WEB_SEARCH_ENGINE = WEB_SEARCH_ENGINE
|
||||
app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
app.state.config.WEB_SEARCH_RESULT_COUNT = WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE
|
||||
app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||
@@ -664,11 +671,8 @@ app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
||||
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
||||
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
||||
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
|
||||
|
||||
app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
|
||||
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
|
||||
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
|
||||
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
||||
@@ -788,6 +792,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
||||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||
|
||||
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||
app.state.config.WHISPER_VAD_FILTER = WHISPER_VAD_FILTER
|
||||
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
|
||||
|
||||
app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
|
||||
@@ -1022,14 +1027,19 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
|
||||
continue
|
||||
|
||||
model_tags = [
|
||||
tag.get("name")
|
||||
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
|
||||
]
|
||||
tags = [tag.get("name") for tag in model.get("tags", [])]
|
||||
try:
|
||||
model_tags = [
|
||||
tag.get("name")
|
||||
for tag in model.get("info", {}).get("meta", {}).get("tags", [])
|
||||
]
|
||||
tags = [tag.get("name") for tag in model.get("tags", [])]
|
||||
|
||||
tags = list(set(model_tags + tags))
|
||||
model["tags"] = [{"name": tag} for tag in tags]
|
||||
tags = list(set(model_tags + tags))
|
||||
model["tags"] = [{"name": tag} for tag in tags]
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing model tags: {e}")
|
||||
model["tags"] = []
|
||||
pass
|
||||
|
||||
models.append(model)
|
||||
|
||||
@@ -1199,7 +1209,7 @@ async def chat_action(
|
||||
@app.post("/api/tasks/stop/{task_id}")
|
||||
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
|
||||
try:
|
||||
result = await stop_task(task_id) # Use the function from tasks.py
|
||||
result = await stop_task(task_id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
@@ -1207,7 +1217,19 @@ async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
@app.get("/api/tasks")
|
||||
async def list_tasks_endpoint(user=Depends(get_verified_user)):
|
||||
return {"tasks": list_tasks()} # Use the function from tasks.py
|
||||
return {"tasks": list_tasks()}
|
||||
|
||||
|
||||
@app.get("/api/tasks/chat/{chat_id}")
|
||||
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id(chat_id)
|
||||
if chat is None or chat.user_id != user.id:
|
||||
return {"task_ids": []}
|
||||
|
||||
task_ids = list_task_ids_by_chat_id(chat_id)
|
||||
|
||||
print(f"Task IDs for chat {chat_id}: {task_ids}")
|
||||
return {"task_ids": task_ids}
|
||||
|
||||
|
||||
##################################
|
||||
@@ -1263,7 +1285,7 @@ async def get_app_config(request: Request):
|
||||
{
|
||||
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
|
||||
"enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION,
|
||||
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
|
||||
@@ -63,14 +63,15 @@ class MemoriesTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_memory_by_id(
|
||||
def update_memory_by_id_and_user_id(
|
||||
self,
|
||||
id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id).update(
|
||||
db.query(Memory).filter_by(id=id, user_id=user_id).update(
|
||||
{"content": content, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
@@ -297,7 +297,9 @@ def query_collection_with_hybrid_search(
|
||||
collection_results = {}
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
log.debug(f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}")
|
||||
log.debug(
|
||||
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
|
||||
)
|
||||
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
|
||||
collection_name=collection_name
|
||||
)
|
||||
@@ -619,7 +621,9 @@ def generate_openai_batch_embeddings(
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
log.debug(f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}")
|
||||
log.debug(
|
||||
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
|
||||
)
|
||||
json_data = {"input": texts, "model": model}
|
||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||
@@ -662,7 +666,9 @@ def generate_ollama_batch_embeddings(
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
log.debug(f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}")
|
||||
log.debug(
|
||||
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
|
||||
)
|
||||
json_data = {"input": texts, "model": model}
|
||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||
|
||||
@@ -28,9 +28,9 @@ from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
PLAYWRIGHT_WS_URL,
|
||||
PLAYWRIGHT_TIMEOUT,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
WEB_LOADER_ENGINE,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
TAVILY_API_KEY,
|
||||
@@ -584,13 +584,6 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
|
||||
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
|
||||
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
|
||||
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
|
||||
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
|
||||
RAG_WEB_LOADER_ENGINES["tavily"] = SafeTavilyLoader
|
||||
|
||||
|
||||
def get_web_loader(
|
||||
urls: Union[str, Sequence[str]],
|
||||
verify_ssl: bool = True,
|
||||
@@ -608,27 +601,36 @@ def get_web_loader(
|
||||
"trust_env": trust_env,
|
||||
}
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "playwright":
|
||||
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
|
||||
WebLoaderClass = SafeWebBaseLoader
|
||||
if WEB_LOADER_ENGINE.value == "playwright":
|
||||
WebLoaderClass = SafePlaywrightURLLoader
|
||||
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
|
||||
if PLAYWRIGHT_WS_URL.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
if WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
WebLoaderClass = SafeFireCrawlLoader
|
||||
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "tavily":
|
||||
if WEB_LOADER_ENGINE.value == "tavily":
|
||||
WebLoaderClass = SafeTavilyLoader
|
||||
web_loader_args["api_key"] = TAVILY_API_KEY.value
|
||||
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
|
||||
|
||||
# Create the appropriate WebLoader based on the configuration
|
||||
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
if WebLoaderClass:
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
|
||||
log.debug(
|
||||
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
|
||||
web_loader.__class__.__name__,
|
||||
len(safe_urls),
|
||||
)
|
||||
log.debug(
|
||||
"Using WEB_LOADER_ENGINE %s for %s URLs",
|
||||
web_loader.__class__.__name__,
|
||||
len(safe_urls),
|
||||
)
|
||||
|
||||
return web_loader
|
||||
return web_loader
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
|
||||
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
|
||||
)
|
||||
|
||||
@@ -330,7 +330,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -384,7 +384,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -440,7 +440,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
status_code=getattr(r, "status", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
@@ -497,7 +497,11 @@ def transcribe(request: Request, file_path):
|
||||
)
|
||||
|
||||
model = request.app.state.faster_whisper_model
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
segments, info = model.transcribe(
|
||||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
% (info.language, info.language_probability)
|
||||
@@ -624,10 +628,7 @@ def transcribe(request: Request, file_path):
|
||||
elif request.app.state.config.STT_ENGINE == "azure":
|
||||
# Check file exists and size
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Audio file not found"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Audio file not found")
|
||||
|
||||
# Check file size (Azure has a larger limit of 200MB)
|
||||
file_size = os.path.getsize(file_path)
|
||||
@@ -643,11 +644,22 @@ def transcribe(request: Request, file_path):
|
||||
|
||||
# IF NO LOCALES, USE DEFAULTS
|
||||
if len(locales) < 2:
|
||||
locales = ['en-US', 'es-ES', 'es-MX', 'fr-FR', 'hi-IN',
|
||||
'it-IT','de-DE', 'en-GB', 'en-IN', 'ja-JP',
|
||||
'ko-KR', 'pt-BR', 'zh-CN']
|
||||
locales = ','.join(locales)
|
||||
|
||||
locales = [
|
||||
"en-US",
|
||||
"es-ES",
|
||||
"es-MX",
|
||||
"fr-FR",
|
||||
"hi-IN",
|
||||
"it-IT",
|
||||
"de-DE",
|
||||
"en-GB",
|
||||
"en-IN",
|
||||
"ja-JP",
|
||||
"ko-KR",
|
||||
"pt-BR",
|
||||
"zh-CN",
|
||||
]
|
||||
locales = ",".join(locales)
|
||||
|
||||
if not api_key or not region:
|
||||
raise HTTPException(
|
||||
@@ -658,22 +670,26 @@ def transcribe(request: Request, file_path):
|
||||
r = None
|
||||
try:
|
||||
# Prepare the request
|
||||
data = {'definition': json.dumps({
|
||||
"locales": locales.split(','),
|
||||
"diarization": {"maxSpeakers": 3,"enabled": True}
|
||||
} if locales else {}
|
||||
)
|
||||
data = {
|
||||
"definition": json.dumps(
|
||||
{
|
||||
"locales": locales.split(","),
|
||||
"diarization": {"maxSpeakers": 3, "enabled": True},
|
||||
}
|
||||
if locales
|
||||
else {}
|
||||
)
|
||||
}
|
||||
url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
|
||||
|
||||
|
||||
# Use context manager to ensure file is properly closed
|
||||
with open(file_path, 'rb') as audio_file:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
r = requests.post(
|
||||
url=url,
|
||||
files={'audio': audio_file},
|
||||
files={"audio": audio_file},
|
||||
data=data,
|
||||
headers={
|
||||
'Ocp-Apim-Subscription-Key': api_key,
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -681,11 +697,11 @@ def transcribe(request: Request, file_path):
|
||||
response = r.json()
|
||||
|
||||
# Extract transcript from response
|
||||
if not response.get('combinedPhrases'):
|
||||
if not response.get("combinedPhrases"):
|
||||
raise ValueError("No transcription found in response")
|
||||
|
||||
# Get the full transcript from combinedPhrases
|
||||
transcript = response['combinedPhrases'][0].get('text', '').strip()
|
||||
transcript = response["combinedPhrases"][0].get("text", "").strip()
|
||||
if not transcript:
|
||||
raise ValueError("Empty transcript in response")
|
||||
|
||||
@@ -718,7 +734,7 @@ def transcribe(request: Request, file_path):
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, 'status_code', 500) if r else 500,
|
||||
status_code=getattr(r, "status_code", 500) if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
|
||||
@@ -231,13 +231,15 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
|
||||
entry = connection_app.entries[0]
|
||||
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||
email = entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"]
|
||||
email = entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"].value # retrive the Attribute value
|
||||
if not email:
|
||||
raise HTTPException(400, "User does not have a valid email address.")
|
||||
elif isinstance(email, str):
|
||||
email = email.lower()
|
||||
elif isinstance(email, list):
|
||||
email = email[0].lower()
|
||||
else:
|
||||
email = str(email).lower()
|
||||
|
||||
cn = str(entry["cn"])
|
||||
user_dn = entry.entry_dn
|
||||
|
||||
@@ -579,7 +579,12 @@ async def clone_chat_by_id(
|
||||
|
||||
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
|
||||
if user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
else:
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
|
||||
@@ -159,7 +159,6 @@ async def create_new_knowledge(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_EXISTS,
|
||||
)
|
||||
|
||||
|
||||
|
||||
############################
|
||||
@@ -168,20 +167,17 @@ async def create_new_knowledge(
|
||||
|
||||
|
||||
@router.post("/reindex", response_model=bool)
|
||||
async def reindex_knowledge_files(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
|
||||
|
||||
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
|
||||
|
||||
|
||||
for knowledge_base in knowledge_bases:
|
||||
try:
|
||||
files = Files.get_files_by_ids(knowledge_base.data.get("file_ids", []))
|
||||
@@ -195,34 +191,40 @@ async def reindex_knowledge_files(
|
||||
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error deleting vector DB collection"
|
||||
detail=f"Error deleting vector DB collection",
|
||||
)
|
||||
|
||||
|
||||
failed_files = []
|
||||
for file in files:
|
||||
try:
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=file.id, collection_name=knowledge_base.id),
|
||||
ProcessFileForm(
|
||||
file_id=file.id, collection_name=knowledge_base.id
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error processing file {file.filename} (ID: {file.id}): {str(e)}")
|
||||
log.error(
|
||||
f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
|
||||
)
|
||||
failed_files.append({"file_id": file.id, "error": str(e)})
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error processing knowledge base"
|
||||
detail=f"Error processing knowledge base",
|
||||
)
|
||||
|
||||
|
||||
if failed_files:
|
||||
log.warning(f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}")
|
||||
log.warning(
|
||||
f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
|
||||
)
|
||||
for failed in failed_files:
|
||||
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
|
||||
|
||||
|
||||
log.info("Reindexing completed successfully")
|
||||
return True
|
||||
|
||||
@@ -742,6 +744,3 @@ def add_files_to_knowledge_batch(
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -153,7 +153,9 @@ async def update_memory_by_id(
|
||||
form_data: MemoryUpdateModel,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
memory = Memories.update_memory_by_id(memory_id, form_data.content)
|
||||
memory = Memories.update_memory_by_id_and_user_id(
|
||||
memory_id, user.id, form_data.content
|
||||
)
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -88,6 +88,10 @@ class ChatPermissions(BaseModel):
|
||||
file_upload: bool = True
|
||||
delete: bool = True
|
||||
edit: bool = True
|
||||
stt: bool = True
|
||||
tts: bool = True
|
||||
call: bool = True
|
||||
multiple_models: bool = True
|
||||
temporary: bool = True
|
||||
temporary_enforced: bool = False
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from open_webui.models.users import Users, UserNameResponse
|
||||
from open_webui.models.channels import Channels
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.utils.redis import (
|
||||
parse_redis_sentinel_url,
|
||||
get_sentinels_from_env,
|
||||
AsyncRedisSentinelManager,
|
||||
get_sentinel_url_from_env,
|
||||
)
|
||||
|
||||
from open_webui.env import (
|
||||
@@ -38,15 +37,10 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
|
||||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
if WEBSOCKET_SENTINEL_HOSTS:
|
||||
redis_config = parse_redis_sentinel_url(WEBSOCKET_REDIS_URL)
|
||||
mgr = AsyncRedisSentinelManager(
|
||||
WEBSOCKET_SENTINEL_HOSTS.split(","),
|
||||
sentinel_port=int(WEBSOCKET_SENTINEL_PORT),
|
||||
redis_port=redis_config["port"],
|
||||
service=redis_config["service"],
|
||||
db=redis_config["db"],
|
||||
username=redis_config["username"],
|
||||
password=redis_config["password"],
|
||||
mgr = socketio.AsyncRedisManager(
|
||||
get_sentinel_url_from_env(
|
||||
WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
|
||||
)
|
||||
)
|
||||
else:
|
||||
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
|
||||
@@ -345,16 +339,17 @@ def get_event_emitter(request_info, update_db=True):
|
||||
request_info["message_id"],
|
||||
)
|
||||
|
||||
content = message.get("content", "")
|
||||
content += event_data.get("data", {}).get("content", "")
|
||||
if message:
|
||||
content = message.get("content", "")
|
||||
content += event_data.get("data", {}).get("content", "")
|
||||
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
request_info["chat_id"],
|
||||
request_info["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
if "type" in event_data and event_data["type"] == "replace":
|
||||
content = event_data.get("data", {}).get("content", "")
|
||||
|
||||
@@ -5,16 +5,23 @@ from uuid import uuid4
|
||||
|
||||
# A dictionary to keep track of active tasks
|
||||
tasks: Dict[str, asyncio.Task] = {}
|
||||
chat_tasks = {}
|
||||
|
||||
|
||||
def cleanup_task(task_id: str):
|
||||
def cleanup_task(task_id: str, id=None):
|
||||
"""
|
||||
Remove a completed or canceled task from the global `tasks` dictionary.
|
||||
"""
|
||||
tasks.pop(task_id, None) # Remove the task if it exists
|
||||
|
||||
# If an ID is provided, remove the task from the chat_tasks dictionary
|
||||
if id and task_id in chat_tasks.get(id, []):
|
||||
chat_tasks[id].remove(task_id)
|
||||
if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
|
||||
chat_tasks.pop(id, None)
|
||||
|
||||
def create_task(coroutine):
|
||||
|
||||
def create_task(coroutine, id=None):
|
||||
"""
|
||||
Create a new asyncio task and add it to the global task dictionary.
|
||||
"""
|
||||
@@ -22,9 +29,15 @@ def create_task(coroutine):
|
||||
task = asyncio.create_task(coroutine) # Create the task
|
||||
|
||||
# Add a done callback for cleanup
|
||||
task.add_done_callback(lambda t: cleanup_task(task_id))
|
||||
|
||||
task.add_done_callback(lambda t: cleanup_task(task_id, id))
|
||||
tasks[task_id] = task
|
||||
|
||||
# If an ID is provided, associate the task with that ID
|
||||
if chat_tasks.get(id):
|
||||
chat_tasks[id].append(task_id)
|
||||
else:
|
||||
chat_tasks[id] = [task_id]
|
||||
|
||||
return task_id, task
|
||||
|
||||
|
||||
@@ -42,6 +55,13 @@ def list_tasks():
|
||||
return list(tasks.keys())
|
||||
|
||||
|
||||
def list_task_ids_by_chat_id(id):
|
||||
"""
|
||||
List all tasks associated with a specific ID.
|
||||
"""
|
||||
return chat_tasks.get(id, [])
|
||||
|
||||
|
||||
async def stop_task(task_id: str):
|
||||
"""
|
||||
Cancel a running task and remove it from the global task list.
|
||||
|
||||
@@ -235,46 +235,30 @@ async def chat_completion_tools_handler(
|
||||
if isinstance(tool_result, str):
|
||||
tool = tools[tool_function_name]
|
||||
tool_id = tool.get("tool_id", "")
|
||||
|
||||
tool_name = (
|
||||
f"{tool_id}/{tool_function_name}"
|
||||
if tool_id
|
||||
else f"{tool_function_name}"
|
||||
)
|
||||
if tool.get("metadata", {}).get("citation", False) or tool.get(
|
||||
"direct", False
|
||||
):
|
||||
|
||||
# Citation is enabled for this tool
|
||||
sources.append(
|
||||
{
|
||||
"source": {
|
||||
"name": (
|
||||
f"TOOL:" + f"{tool_id}/{tool_function_name}"
|
||||
if tool_id
|
||||
else f"{tool_function_name}"
|
||||
),
|
||||
"name": (f"TOOL:{tool_name}"),
|
||||
},
|
||||
"document": [tool_result, *tool_result_files],
|
||||
"metadata": [
|
||||
{
|
||||
"source": (
|
||||
f"TOOL:" + f"{tool_id}/{tool_function_name}"
|
||||
if tool_id
|
||||
else f"{tool_function_name}"
|
||||
)
|
||||
}
|
||||
],
|
||||
"document": [tool_result],
|
||||
"metadata": [{"source": (f"TOOL:{tool_name}")}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
sources.append(
|
||||
{
|
||||
"source": {},
|
||||
"document": [tool_result, *tool_result_files],
|
||||
"metadata": [
|
||||
{
|
||||
"source": (
|
||||
f"TOOL:" + f"{tool_id}/{tool_function_name}"
|
||||
if tool_id
|
||||
else f"{tool_function_name}"
|
||||
)
|
||||
}
|
||||
],
|
||||
}
|
||||
# Citation is not enabled for this tool
|
||||
body["messages"] = add_or_update_user_message(
|
||||
f"\nTool `{tool_name}` Output: {tool_result}",
|
||||
body["messages"],
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -550,13 +534,20 @@ async def chat_image_generation_handler(
|
||||
}
|
||||
)
|
||||
|
||||
for image in images:
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "message",
|
||||
"data": {"content": f"\n"},
|
||||
}
|
||||
)
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "files",
|
||||
"data": {
|
||||
"files": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": image["url"],
|
||||
}
|
||||
for image in images
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
|
||||
except Exception as e:
|
||||
@@ -2261,7 +2252,9 @@ async def process_chat_response(
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(post_response_handler, response, events)
|
||||
task_id, _ = create_task(post_response_handler(response, events))
|
||||
task_id, _ = create_task(
|
||||
post_response_handler(response, events), id=metadata["chat_id"]
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
else:
|
||||
|
||||
@@ -114,9 +114,12 @@ async def get_all_models(request, user: UserModel = None):
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id is None:
|
||||
for model in models:
|
||||
if (
|
||||
custom_model.id == model["id"]
|
||||
or custom_model.id == model["id"].split(":")[0]
|
||||
if custom_model.id == model["id"] or (
|
||||
model.get("owned_by") == "ollama"
|
||||
and custom_model.id
|
||||
== model["id"].split(":")[
|
||||
0
|
||||
] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
|
||||
):
|
||||
if custom_model.is_active:
|
||||
model["name"] = custom_model.name
|
||||
|
||||
@@ -4,7 +4,7 @@ from redis import asyncio as aioredis
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def parse_redis_sentinel_url(redis_url):
|
||||
def parse_redis_service_url(redis_url):
|
||||
parsed_url = urlparse(redis_url)
|
||||
if parsed_url.scheme != "redis":
|
||||
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
|
||||
@@ -20,7 +20,7 @@ def parse_redis_sentinel_url(redis_url):
|
||||
|
||||
def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
|
||||
if redis_sentinels:
|
||||
redis_config = parse_redis_sentinel_url(redis_url)
|
||||
redis_config = parse_redis_service_url(redis_url)
|
||||
sentinel = redis.sentinel.Sentinel(
|
||||
redis_sentinels,
|
||||
port=redis_config["port"],
|
||||
@@ -45,65 +45,14 @@ def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
|
||||
return []
|
||||
|
||||
|
||||
class AsyncRedisSentinelManager(socketio.AsyncRedisManager):
|
||||
def __init__(
|
||||
self,
|
||||
sentinel_hosts,
|
||||
sentinel_port=26379,
|
||||
redis_port=6379,
|
||||
service="mymaster",
|
||||
db=0,
|
||||
username=None,
|
||||
password=None,
|
||||
channel="socketio",
|
||||
write_only=False,
|
||||
logger=None,
|
||||
redis_options=None,
|
||||
):
|
||||
"""
|
||||
Initialize the Redis Sentinel Manager.
|
||||
This implementation mostly replicates the __init__ of AsyncRedisManager and
|
||||
overrides _redis_connect() with a version that uses Redis Sentinel
|
||||
|
||||
:param sentinel_hosts: List of Sentinel hosts
|
||||
:param sentinel_port: Sentinel Port
|
||||
:param redis_port: Redis Port (currently unsupported by aioredis!)
|
||||
:param service: Master service name in Sentinel
|
||||
:param db: Redis database to use
|
||||
:param username: Redis username (if any) (currently unsupported by aioredis!)
|
||||
:param password: Redis password (if any)
|
||||
:param channel: The channel name on which the server sends and receives
|
||||
notifications. Must be the same in all the servers.
|
||||
:param write_only: If set to ``True``, only initialize to emit events. The
|
||||
default of ``False`` initializes the class for emitting
|
||||
and receiving.
|
||||
:param redis_options: additional keyword arguments to be passed to
|
||||
``aioredis.from_url()``.
|
||||
"""
|
||||
self._sentinels = [(host, sentinel_port) for host in sentinel_hosts]
|
||||
self._redis_port = redis_port
|
||||
self._service = service
|
||||
self._db = db
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._channel = channel
|
||||
self.redis_options = redis_options or {}
|
||||
|
||||
# connect and call grandparent constructor
|
||||
self._redis_connect()
|
||||
super(socketio.AsyncRedisManager, self).__init__(
|
||||
channel=channel, write_only=write_only, logger=logger
|
||||
)
|
||||
|
||||
def _redis_connect(self):
|
||||
"""Establish connections to Redis through Sentinel."""
|
||||
sentinel = aioredis.sentinel.Sentinel(
|
||||
self._sentinels,
|
||||
port=self._redis_port,
|
||||
db=self._db,
|
||||
password=self._password,
|
||||
**self.redis_options,
|
||||
)
|
||||
|
||||
self.redis = sentinel.master_for(self._service)
|
||||
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
|
||||
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
|
||||
redis_config = parse_redis_service_url(redis_url)
|
||||
username = redis_config["username"] or ""
|
||||
password = redis_config["password"] or ""
|
||||
auth_part = ""
|
||||
if username or password:
|
||||
auth_part = f"{username}:{password}@"
|
||||
hosts_part = ",".join(
|
||||
f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(",")
|
||||
)
|
||||
return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"
|
||||
|
||||
@@ -152,6 +152,8 @@ def rag_template(template: str, context: str, query: str):
|
||||
if template.strip() == "":
|
||||
template = DEFAULT_RAG_TEMPLATE
|
||||
|
||||
template = prompt_template(template)
|
||||
|
||||
if "[context]" not in template and "{{CONTEXT}}" not in template:
|
||||
log.debug(
|
||||
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
|
||||
|
||||
Reference in New Issue
Block a user