Merge remote-tracking branch 'origin' into logit_bias

This commit is contained in:
dannyl1u
2025-02-27 23:48:22 -08:00
181 changed files with 10428 additions and 5218 deletions

View File

@@ -9,7 +9,6 @@ from pathlib import Path
from typing import Generic, Optional, TypeVar
from urllib.parse import urlparse
import chromadb
import requests
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -44,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
# Function to run the alembic migrations
def run_migrations():
print("Running migrations")
log.info("Running migrations")
try:
from alembic import command
from alembic.config import Config
@@ -57,7 +56,7 @@ def run_migrations():
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
log.exception(f"Error running migrations: {e}")
run_migrations()
@@ -588,20 +587,6 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
def override_static(path: str, content: str):
# Ensure path is safe
if "/" in path or ".." in path:
log.error(f"Invalid path: {path}")
return
file_path = os.path.join(STATIC_DIR, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists():
@@ -692,12 +677,20 @@ S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
# True only when the variable is the string "true" (any letter case); defaults to False.
S3_USE_ACCELERATE_ENDPOINT = (
os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true"
)
# Raw environment value, or None when the variable is unset.
S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None)
# GCS bucket name and Google credentials JSON; both are None when unset.
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
"GOOGLE_APPLICATION_CREDENTIALS_JSON", None
)
# Azure storage endpoint, container name and key; each is None when unset.
AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None)
AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None)
AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None)
####################################
# File Upload DIR
####################################
@@ -797,6 +790,9 @@ ENABLE_OPENAI_API = PersistentConfig(
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@@ -1101,7 +1097,7 @@ try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
print(f"Error loading WEBUI_BANNERS: {e}")
log.exception(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
@@ -1377,6 +1373,44 @@ Responses from models: {{responses}}"""
# Code Interpreter
####################################
# Code-execution settings.
# NOTE(review): PersistentConfig arguments look like (name, dotted config path,
# initial value) — confirm against the PersistentConfig definition.

# Execution engine; "pyodide" when the environment variable is unset.
CODE_EXECUTION_ENGINE = PersistentConfig(
"CODE_EXECUTION_ENGINE",
"code_execution.engine",
os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"),
)
# Jupyter URL; empty string when unset.
CODE_EXECUTION_JUPYTER_URL = PersistentConfig(
"CODE_EXECUTION_JUPYTER_URL",
"code_execution.jupyter.url",
os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
)
# Jupyter auth setting; empty string when unset.
CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH",
"code_execution.jupyter.auth",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
)
# Jupyter auth token; empty string when unset.
CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN",
"code_execution.jupyter.auth_token",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
)
# Jupyter auth password; empty string when unset.
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD",
"code_execution.jupyter.auth_password",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
)
# Jupyter timeout, parsed with int(); defaults to 60. A non-integer value
# raises ValueError when this module is imported.
CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig(
"CODE_EXECUTION_JUPYTER_TIMEOUT",
"code_execution.jupyter.timeout",
int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")),
)
ENABLE_CODE_INTERPRETER = PersistentConfig(
"ENABLE_CODE_INTERPRETER",
"code_interpreter.enable",
@@ -1398,26 +1432,48 @@ CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
# Jupyter settings for the code interpreter.
# Each value is read from its CODE_INTERPRETER_JUPYTER_* environment variable;
# when that variable is not set, the matching CODE_EXECUTION_JUPYTER_* variable
# is used, and only then the built-in default.
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_URL",
    "code_interpreter.jupyter.url",
    os.environ.get(
        "CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "")
    ),
)

CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH",
    "code_interpreter.jupyter.auth",
    os.environ.get(
        "CODE_INTERPRETER_JUPYTER_AUTH",
        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
    ),
)

CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
    "code_interpreter.jupyter.auth_token",
    os.environ.get(
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
    ),
)

CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
    "code_interpreter.jupyter.auth_password",
    os.environ.get(
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
        os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
    ),
)

# Timeout is parsed with int(); defaults to 60 when neither variable is set.
# A non-integer value raises ValueError when this module is imported.
CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_TIMEOUT",
    "code_interpreter.jupyter.timeout",
    int(
        os.environ.get(
            "CODE_INTERPRETER_JUPYTER_TIMEOUT",
            os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
        )
    ),
)
@@ -1445,21 +1501,27 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
if VECTOR_DB == "chroma":
    # chromadb is imported only when Chroma is the selected vector DB.
    import chromadb

    CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
    CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
    CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
    CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
    CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
    CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get(
        "CHROMA_CLIENT_AUTH_CREDENTIALS", ""
    )
    # Comma-separated list of header=value pairs
    CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
    if CHROMA_HTTP_HEADERS:
        # Split on the first "=" only: a value that itself contains "="
        # (e.g. base64 padding) would otherwise yield more than two parts
        # and make dict() raise ValueError.
        CHROMA_HTTP_HEADERS = dict(
            [pair.split("=", 1) for pair in CHROMA_HTTP_HEADERS.split(",")]
        )
    else:
        CHROMA_HTTP_HEADERS = None
    # True only when the variable is the string "true" (any letter case).
    CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
# Milvus
@@ -1513,6 +1575,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig(
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)
# OneDrive integration toggle: True only when the variable is the string
# "true" (any letter case); defaults to False.
ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
"ENABLE_ONEDRIVE_INTEGRATION",
"onedrive.enable",
os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
)
# OneDrive client id; empty string when unset.
ONEDRIVE_CLIENT_ID = PersistentConfig(
"ONEDRIVE_CLIENT_ID",
"onedrive.client_id",
os.environ.get("ONEDRIVE_CLIENT_ID", ""),
)
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
@@ -1526,6 +1600,26 @@ TIKA_SERVER_URL = PersistentConfig(
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
# Document Intelligence endpoint; empty string when unset.
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
)
# Document Intelligence key; empty string when unset.
DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
"DOCUMENT_INTELLIGENCE_KEY",
"rag.document_intelligence_key",
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
)
# Boolean flag: True only when the variable is the string "true"
# (any letter case); defaults to False.
BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
"BYPASS_EMBEDDING_AND_RETRIEVAL",
"rag.bypass_embedding_and_retrieval",
os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
)
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
@@ -1541,6 +1635,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
# Boolean flag: True only when the variable is the string "true"
# (any letter case); defaults to False.
RAG_FULL_CONTEXT = PersistentConfig(
"RAG_FULL_CONTEXT",
"rag.full_context",
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
)
RAG_FILE_MAX_COUNT = PersistentConfig(
"RAG_FILE_MAX_COUNT",
"rag.file.max_count",
@@ -1655,7 +1755,7 @@ Respond to the user query using the provided context, incorporating inline citat
- Respond in the same language as the user's query.
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `<source_id>` tag is explicitly provided in the context.**
- Do not cite if the <source_id> tag is not provided in the context.
- Do not use XML tags in your response.
- Ensure citations are concise and directly related to the information provided.
@@ -1736,6 +1836,12 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
# Boolean flag: True only when the variable is the string "true"
# (any letter case); defaults to False.
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
"rag.web.search.bypass_embedding_and_retrieval",
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
)
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
@@ -1883,10 +1989,34 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
# Web loader engine name; "safe_web" when the variable is unset.
RAG_WEB_LOADER_ENGINE = PersistentConfig(
    "RAG_WEB_LOADER_ENGINE",
    "rag.web.loader.engine",
    os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
)

# Parsed into a real boolean: True only when the variable is the string
# "true" (any letter case). Storing the raw environment string instead would
# make any non-empty value, including "false", truthy.
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
    "RAG_WEB_SEARCH_TRUST_ENV",
    "rag.web.search.trust_env",
    os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
)

# Playwright websocket URI; None when unset.
PLAYWRIGHT_WS_URI = PersistentConfig(
    "PLAYWRIGHT_WS_URI",
    "rag.web.loader.engine.playwright.ws.uri",
    os.environ.get("PLAYWRIGHT_WS_URI", None),
)

# Firecrawl API key; empty string when unset.
FIRECRAWL_API_KEY = PersistentConfig(
    "FIRECRAWL_API_KEY",
    "firecrawl.api_key",
    os.environ.get("FIRECRAWL_API_KEY", ""),
)

# Firecrawl API base URL; defaults to the public endpoint below.
FIRECRAWL_API_BASE_URL = PersistentConfig(
    "FIRECRAWL_API_BASE_URL",
    "firecrawl.api_url",
    os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
)
####################################
@@ -2099,6 +2229,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
# Gemini base URL for image generation; falls back to the module-level
# GEMINI_API_BASE_URL when IMAGES_GEMINI_API_BASE_URL is unset.
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
"IMAGES_GEMINI_API_BASE_URL",
"image_generation.gemini.api_base_url",
os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
)
# Gemini API key for image generation; falls back to the module-level
# GEMINI_API_KEY when IMAGES_GEMINI_API_KEY is unset.
IMAGES_GEMINI_API_KEY = PersistentConfig(
"IMAGES_GEMINI_API_KEY",
"image_generation.gemini.api_key",
os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
)
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
@@ -2275,7 +2416,7 @@ LDAP_SEARCH_BASE = PersistentConfig(
# LDAP search filter. LDAP_SEARCH_FILTER is read first; when it is not set the
# plural LDAP_SEARCH_FILTERS variable is accepted as a fallback, and an empty
# string is used when neither is present.
LDAP_SEARCH_FILTERS = PersistentConfig(
    "LDAP_SEARCH_FILTER",
    "ldap.server.search_filter",
    os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")),
)
LDAP_USE_TLS = PersistentConfig(

View File

@@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################

# Enabled only when the variable is the string "true" (any letter case).
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"

# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"

# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")

# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()

# Falls back to 2048 when the variable is unset, empty, or not an integer.
try:
    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
    MAX_BODY_LOG_SIZE = 2048

# Comma separated list for urls to exclude from audit.
# Each entry is trimmed of surrounding whitespace, then of leading slashes.
AUDIT_EXCLUDED_PATHS = [
    path.strip().lstrip("/")
    for path in os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(",")
]

View File

@@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
import asyncio
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
@@ -76,11 +77,13 @@ async def get_function_models(request):
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:

View File

@@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
from open_webui.utils.logger import start_logger
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@@ -95,12 +98,19 @@ from open_webui.config import (
OLLAMA_API_CONFIGS,
# OpenAI
ENABLE_OPENAI_API,
ONEDRIVE_CLIENT_ID,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
# Code Interpreter
# Code Execution
CODE_EXECUTION_ENGINE,
CODE_EXECUTION_JUPYTER_URL,
CODE_EXECUTION_JUPYTER_AUTH,
CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
CODE_EXECUTION_JUPYTER_TIMEOUT,
ENABLE_CODE_INTERPRETER,
CODE_INTERPRETER_ENGINE,
CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -108,6 +118,7 @@ from open_webui.config import (
CODE_INTERPRETER_JUPYTER_AUTH,
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
CODE_INTERPRETER_JUPYTER_TIMEOUT,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@@ -126,6 +137,8 @@ from open_webui.config import (
IMAGE_STEPS,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY,
# Audio
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
@@ -140,6 +153,10 @@ from open_webui.config import (
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
PLAYWRIGHT_WS_URI,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
RAG_WEB_LOADER_ENGINE,
WHISPER_MODEL,
DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
@@ -147,6 +164,8 @@ from open_webui.config import (
# Retrieval
RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE,
RAG_FULL_CONTEXT,
BYPASS_EMBEDDING_AND_RETRIEVAL,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@@ -166,6 +185,8 @@ from open_webui.config import (
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
RAG_TOP_K,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
@@ -174,6 +195,7 @@ from open_webui.config import (
YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE,
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_WEB_SEARCH_TRUST_ENV,
@@ -200,11 +222,13 @@ from open_webui.config import (
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY,
ONEDRIVE_CLIENT_ID,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENABLE_GOOGLE_DRIVE_INTEGRATION,
ENABLE_ONEDRIVE_INTEGRATION,
UPLOAD_DIR,
# WebUI
WEBUI_AUTH,
@@ -283,8 +307,11 @@ from open_webui.config import (
reset_config,
)
from open_webui.env import (
AUDIT_EXCLUDED_PATHS,
AUDIT_LOG_LEVEL,
CHANGELOG,
GLOBAL_LOG_LEVEL,
MAX_BODY_LOG_SIZE,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
@@ -369,6 +396,7 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
start_logger()
if RESET_CONFIG_ON_START:
reset_config()
@@ -509,6 +537,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
@@ -516,6 +547,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
@@ -543,9 +576,13 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
@@ -569,7 +606,11 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
@@ -613,10 +654,19 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
########################################
#
# CODE INTERPRETER
# CODE EXECUTION
#
########################################
app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
@@ -629,6 +679,7 @@ app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
########################################
#
@@ -643,6 +694,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
@@ -844,6 +898,19 @@ app.include_router(
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
# Resolve the configured audit level; a value AuditLevel does not accept is
# logged and treated as AuditLevel.NONE.
try:
    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
except ValueError as e:
    # NOTE(review): `logger` is bound by `from open_webui.utils import logger`;
    # confirm that object actually exposes `.error` (it may be the module
    # rather than a logger instance).
    logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
    audit_level = AuditLevel.NONE

# The audit middleware is only registered when the level is not NONE.
if audit_level != AuditLevel.NONE:
    app.add_middleware(
        AuditLoggingMiddleware,
        audit_level=audit_level,
        excluded_paths=AUDIT_EXCLUDED_PATHS,
        max_body_size=MAX_BODY_LOG_SIZE,
    )
##################################
#
# Chat Endpoints
@@ -876,7 +943,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
models = await get_all_models(request)
models = await get_all_models(request, user=user)
# Filter out filter pipelines
models = [
@@ -905,7 +972,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
@app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)):
models = await get_all_base_models(request)
models = await get_all_base_models(request, user=user)
return {"data": models}
@@ -916,7 +983,7 @@ async def chat_completion(
user=Depends(get_verified_user),
):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
@@ -952,7 +1019,7 @@ async def chat_completion(
"files": form_data.get("files", None),
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model_info,
"model": model_info.model_dump() if model_info else model,
"direct": model_item.get("direct", False),
**(
{"function_calling": "native"}
@@ -1111,6 +1178,7 @@ async def get_app_config(request: Request):
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
}
if user is not None
else {}
@@ -1120,6 +1188,9 @@ async def get_app_config(request: Request):
{
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"code": {
"engine": app.state.config.CODE_EXECUTION_ENGINE,
},
"audio": {
"tts": {
"engine": app.state.config.TTS_ENGINE,
@@ -1139,6 +1210,7 @@ async def get_app_config(request: Request):
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
}
if user is not None
else {}

View File

@@ -1,3 +1,4 @@
import logging
import json
import time
import uuid
@@ -5,7 +6,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
@@ -670,7 +674,7 @@ class ChatTable:
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats))
log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
@@ -731,7 +735,7 @@ class ChatTable:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
@@ -752,7 +756,7 @@ class ChatTable:
)
all_chats = query.all()
print("all_chats", all_chats)
log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
@@ -810,7 +814,7 @@ class ChatTable:
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
log.info(f"Count of chats for tag '{tag_name}': {count}")
return count

View File

@@ -118,7 +118,7 @@ class FeedbackTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error creating a new feedback: {e}")
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:

View File

@@ -119,7 +119,7 @@ class FilesTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error inserting a new file: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:

View File

@@ -82,7 +82,7 @@ class FolderTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new folder: {e}")
return None
def get_folder_by_id_and_user_id(

View File

@@ -105,7 +105,7 @@ class FunctionsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new function: {e}")
return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
@@ -170,7 +170,7 @@ class FunctionsTable:
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting function valves by id {id}: {e}")
return None
def update_function_valves_by_id(
@@ -202,7 +202,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@@ -225,7 +227,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:

5
backend/open_webui/models/models.py Normal file → Executable file
View File

@@ -166,7 +166,7 @@ class ModelsTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Failed to insert a new model: {e}")
return None
def get_all_models(self) -> list[ModelModel]:
@@ -246,8 +246,7 @@ class ModelsTable:
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
log.exception(f"Failed to update the model by id {id}: {e}")
return None
def delete_model_by_id(self, id: str) -> bool:

View File

@@ -61,7 +61,7 @@ class TagTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new tag: {e}")
return None
def get_tag_by_name_and_user_id(

View File

@@ -131,7 +131,7 @@ class ToolsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
@@ -175,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
@@ -204,7 +204,9 @@ class ToolsTable:
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@@ -227,7 +229,9 @@ class ToolsTable:
return user_settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:

View File

@@ -4,6 +4,7 @@ import ftfy
import sys
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
@@ -76,6 +77,7 @@ known_source_ext = [
"jsx",
"hs",
"lhs",
"json",
]
@@ -147,6 +149,27 @@ class Loader:
file_path=file_path,
mime_type=file_content_type,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(

View File

@@ -1,13 +1,19 @@
import os
import logging
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT:
    def __init__(self, name, **kwargs) -> None:
        # `name` is passed as a %-style argument, so the message needs a
        # matching placeholder; without one the logging module reports a
        # formatting error and the message is never emitted.
        log.info("ColBERT: Loading model %s", name)
        # Use the GPU when torch reports CUDA as available, otherwise the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # True only when the caller passed env="docker".
        DOCKER = kwargs.get("env") == "docker"

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
import asyncio
import requests
import hashlib
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@@ -14,8 +15,10 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.env import (
SRC_LOG_LEVELS,
@@ -80,7 +83,20 @@ def query_doc(
return result
except Exception as e:
print(e)
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e
def get_doc(collection_name: str, user: Optional[UserModel] = None):
    """Fetch the stored items of a vector collection.

    Args:
        collection_name: Name forwarded to ``VECTOR_DB_CLIENT.get``.
        user: Accepted for call-site compatibility; not read by this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.get`` returned when it is truthy,
        otherwise ``None``.

    Raises:
        Exception: Any error raised while fetching or logging the result is
            logged with its traceback and re-raised unchanged.
    """
    try:
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            # Label the entry with this function's name so it is not confused
            # with the output of query_doc.
            log.info(f"get_doc:result {result.ids} {result.metadatas}")
            return result

        # Empty / missing collection: make the "nothing found" outcome explicit.
        return None
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise
@@ -137,47 +153,80 @@ def query_doc_with_hybrid_search(
raise e
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate several "get" results into a single result dict.

    Args:
        get_results: Dicts whose ``"documents"``, ``"metadatas"`` and ``"ids"``
            values are lists with the payload list at index 0.

    Returns:
        A dict with the same three keys, each holding a one-element list that
        wraps the concatenation (in input order) of the per-result payloads.
    """
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    return {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }


def merge_and_sort_query_results(
    query_results: list[dict], k: int, reverse: bool = False
) -> dict:
    """Merge query results, drop duplicate documents, sort, and keep the top ``k``.

    Args:
        query_results: Dicts whose ``"distances"``, ``"documents"`` and
            ``"metadatas"`` values are lists with the payload list at index 0.
        k: Slice bound applied to the sorted entries (``combined[:k]``).
        reverse: Sort by distance descending instead of ascending.

    Returns:
        A dict with ``"distances"``, ``"documents"`` and ``"metadatas"``, each a
        one-element list wrapping the selected values. All three inner lists
        are empty when nothing survives the merge or the slice.
    """
    combined = []
    seen_hashes = set()  # MD5 digests of documents already kept

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            # Only string documents are kept; anything else is skipped.
            if not isinstance(document, str):
                continue
            # The digest is used purely as a uniqueness key. The first
            # occurrence of a document wins, whatever its distance.
            doc_hash = hashlib.md5(document.encode()).hexdigest()
            if doc_hash in seen_hashes:
                continue
            seen_hashes.add(doc_hash)
            combined.append((distance, document, metadata))

    combined.sort(key=lambda entry: entry[0], reverse=reverse)

    # Slice first, then test the slice: unpacking zip(*[]) into three names
    # raises ValueError, which would happen for k == 0 with a non-empty merge.
    top = combined[:k]
    if top:
        sorted_distances, sorted_documents, sorted_metadatas = zip(*top)
    else:
        sorted_distances, sorted_documents, sorted_metadatas = [], [], []

    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Collect the full contents of several collections into one result dict.

    Falsy collection names are skipped. Each remaining name is fetched with
    ``get_doc``; a non-``None`` result is converted with ``model_dump()``.
    A failure for one collection is logged and does not stop the others.

    Args:
        collection_names: Names of the collections to read.

    Returns:
        The output of ``merge_get_results`` over the dumped results.
    """
    dumped = []

    for name in collection_names:
        if not name:
            continue

        try:
            fetched = get_doc(collection_name=name)
            if fetched is not None:
                dumped.append(fetched.model_dump())
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(dumped)
def query_collection(
@@ -290,6 +339,7 @@ def get_embedding_function(
def get_sources_from_files(
request,
files,
queries,
embedding_function,
@@ -297,21 +347,74 @@ def get_sources_from_files(
reranking_function,
r,
hybrid_search,
full_context=False,
):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
relevant_contexts = []
for file in files:
if file.get("context") == "full":
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
else:
context = None
elif (
file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
@@ -331,42 +434,50 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
if full_context:
try:
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
sources = []
@@ -463,7 +574,7 @@ def generate_openai_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating openai batch embeddings: {e}")
return None
@@ -497,7 +608,7 @@ def generate_ollama_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating ollama batch embeddings: {e}")
return None

8
backend/open_webui/retrieval/vector/dbs/chroma.py Normal file → Executable file
View File

@@ -1,4 +1,5 @@
import chromadb
import logging
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
@@ -16,6 +17,10 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
@@ -102,8 +107,7 @@ class ChromaClient:
}
)
return None
except Exception as e:
print(e)
except:
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@@ -1,7 +1,7 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
@@ -10,6 +10,10 @@ from open_webui.config import (
MILVUS_DB,
MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
@@ -168,7 +172,7 @@ class MilvusClient:
try:
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
print("remaining", remaining)
log.info(f"remaining: {remaining}")
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
@@ -195,10 +199,12 @@ class MilvusClient:
if results_count < current_fetch:
break
print(all_results)
log.debug(all_results)
return self._result_to_get_result([all_results])
except Exception as e:
print(e)
log.exception(
f"Error querying collection {collection_name} with limit {limit}: {e}"
)
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@@ -1,4 +1,5 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
cast,
column,
@@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
@@ -82,10 +88,10 @@ class PgvectorClient:
)
)
self.session.commit()
print("Initialization complete.")
log.info("Initialization complete.")
except Exception as e:
self.session.rollback()
print(f"Error during initialization: {e}")
log.exception(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
@@ -150,12 +156,12 @@ class PgvectorClient:
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
print(
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during insert: {e}")
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -184,10 +190,12 @@ class PgvectorClient:
)
self.session.add(new_chunk)
self.session.commit()
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during upsert: {e}")
log.exception(f"Error during upsert: {e}")
raise
def search(
@@ -278,7 +286,7 @@ class PgvectorClient:
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
print(f"Error during search: {e}")
log.exception(f"Error during search: {e}")
return None
def query(
@@ -310,7 +318,7 @@ class PgvectorClient:
metadatas=metadatas,
)
except Exception as e:
print(f"Error during query: {e}")
log.exception(f"Error during query: {e}")
return None
def get(
@@ -334,7 +342,7 @@ class PgvectorClient:
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
print(f"Error during get: {e}")
log.exception(f"Error during get: {e}")
return None
def delete(
@@ -356,22 +364,22 @@ class PgvectorClient:
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
print(f"Deleted {deleted} items from collection '{collection_name}'.")
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during delete: {e}")
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
    """Delete every ``DocumentChunk`` row and commit.

    On any error the session is rolled back, the error is logged with its
    traceback, and the exception is re-raised.
    """
    session = self.session
    try:
        deleted = session.query(DocumentChunk).delete()
        session.commit()
        log.info(
            f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
        )
    except Exception as e:
        # Leave the session usable for the next operation before propagating.
        session.rollback()
        log.exception(f"Error during reset: {e}")
        raise
def close(self) -> None:
@@ -387,9 +395,9 @@ class PgvectorClient:
)
return exists
except Exception as e:
print(f"Error checking collection existence: {e}")
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
    """Delete a collection by delegating to ``delete`` and log the removal.

    Args:
        collection_name: Name of the collection passed to ``self.delete``.
    """
    self.delete(collection_name)
    log.info(f"Collection '{collection_name}' deleted.")

View File

@@ -1,4 +1,5 @@
from typing import Optional
import logging
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
@@ -6,9 +7,13 @@ from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
def __init__(self):
@@ -49,7 +54,7 @@ class QdrantClient:
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
@@ -120,7 +125,7 @@ class QdrantClient:
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@@ -27,8 +27,7 @@ def search_tavily(
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
include_domain = filter_list
response = requests.post(url, include_domain, json=data)
response = requests.post(url, json=data)
response.raise_for_status()
json_response = response.json()

View File

@@ -1,22 +1,38 @@
import socket
import aiohttp
import asyncio
import urllib.parse
import validators
from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
from langchain_community.document_loaders import (
WebBaseLoader,
)
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS
import logging
import socket
import ssl
import urllib.parse
import urllib.request
from collections import defaultdict
from datetime import datetime, time, timedelta
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
Literal,
)
import aiohttp
import certifi
import validators
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URI,
RAG_WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -68,6 +84,314 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
    """Build a metadata dict for a parsed page.

    Args:
        soup: Parsed document exposing ``find(name, attrs=...)``.
        url: Value stored under ``"source"``.

    Returns:
        A dict that always contains ``"source"``. ``"title"``,
        ``"description"`` and ``"language"`` are added only when the matching
        ``<title>``, ``<meta name="description">`` or ``<html>`` element is
        found (and is truthy).
    """
    metadata = {"source": url}

    title_tag = soup.find("title")
    if title_tag:
        metadata["title"] = title_tag.get_text()

    description_tag = soup.find("meta", attrs={"name": "description"})
    if description_tag:
        metadata["description"] = description_tag.get(
            "content", "No description found."
        )

    html_tag = soup.find("html")
    if html_tag:
        metadata["language"] = html_tag.get("lang", "No language found.")

    return metadata
def verify_ssl_cert(url: str, timeout: float = 10.0) -> bool:
    """Verify the TLS certificate presented by the server behind ``url``.

    Args:
        url: URL to check. Anything that does not start with ``https://`` is
            accepted without a network round-trip.
        timeout: Seconds to wait for the TCP connection and TLS handshake.

    Returns:
        ``True`` for non-HTTPS URLs and for HTTPS URLs whose handshake succeeds
        against the certifi CA bundle; ``False`` on an SSL error or on any
        other failure (the latter is logged as a warning).
    """
    if not url.startswith("https://"):
        return True

    try:
        # Parse the URL properly: a manual split leaves "host:port" or
        # "user@host" in the hostname and ignores an explicit port.
        parts = urllib.parse.urlsplit(url)
        hostname = parts.hostname
        if not hostname:
            raise ValueError("URL has no hostname")
        port = parts.port or 443

        context = ssl.create_default_context(cafile=certifi.where())
        # The timeout keeps an unresponsive host from blocking the caller.
        with socket.create_connection((hostname, port), timeout=timeout) as sock:
            # The handshake, including certificate and hostname validation,
            # runs while the socket is being wrapped.
            with context.wrap_socket(sock, server_hostname=hostname):
                pass
        return True
    except ssl.SSLError:
        return False
    except Exception as e:
        log.warning(f"SSL verification failed for {url}: {str(e)}")
        return False
class SafeFireCrawlLoader(BaseLoader):
    """Document loader that runs FireCrawl per URL behind safety checks.

    URLs are processed one after another. Before each URL is handed to a
    ``FireCrawlLoader``, its SSL certificate is optionally verified and the
    configured request rate limit is applied.
    """

    def __init__(
        self,
        web_paths,
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Literal["crawl", "scrape", "map"] = "crawl",
        proxy: Optional[Dict[str, str]] = None,
        params: Optional[Dict] = None,
    ):
        """Store the loader configuration.

        Args:
            web_paths: List of URLs/paths to process.
            verify_ssl: If True, verify SSL certificates.
            trust_env: If True, use proxy settings from environment variables
                when ``proxy`` does not already name a server.
            requests_per_second: Number of requests per second to limit to.
            continue_on_failure (bool): If True, continue loading other URLs on failure.
            api_key: API key for the FireCrawl service. Defaults to None.
                NOTE(review): any environment-variable fallback happens inside
                FireCrawlLoader, not here - confirm against its documentation.
            api_url: Base URL for the FireCrawl API.
            mode: Operation mode selection:
                - 'crawl': Website crawling mode (default)
                - 'scrape': Direct page scraping
                - 'map': Site map generation
            proxy: Proxy override settings. Stored on the instance only; it is
                not forwarded to ``FireCrawlLoader`` by this class.
            params: The parameters to pass to the Firecrawl API.
                Examples include crawlerOptions.
                For more details, visit: https://github.com/mendableai/firecrawl-py
        """
        if trust_env and not (proxy or {}).get("server"):
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                # Build a new dict so the caller's mapping is never mutated.
                proxy = {**(proxy or {}), "server": env_proxy_server}

        self.web_paths = web_paths
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.trust_env = trust_env
        self.continue_on_failure = continue_on_failure
        self.api_key = api_key
        self.api_url = api_url
        self.mode = mode
        self.params = params
        self.proxy = proxy

    def _build_loader(self, url: str):
        """Create the FireCrawlLoader for a single URL."""
        return FireCrawlLoader(
            url=url,
            api_key=self.api_key,
            api_url=self.api_url,
            mode=self.mode,
            params=self.params,
        )

    def lazy_load(self) -> Iterator[Document]:
        """Yield documents for each URL, one URL at a time."""
        for url in self.web_paths:
            try:
                self._safe_process_url_sync(url)
                yield from self._build_loader(url).lazy_load()
            except Exception:
                if self.continue_on_failure:
                    # The message must be the first argument; the traceback
                    # is attached automatically by log.exception.
                    log.exception("Error loading %s", url)
                    continue
                raise

    async def alazy_load(self):
        """Async version of lazy_load."""
        for url in self.web_paths:
            try:
                await self._safe_process_url(url)
                async for document in self._build_loader(url).alazy_load():
                    yield document
            except Exception:
                if self.continue_on_failure:
                    log.exception("Error loading %s", url)
                    continue
                raise

    def _verify_ssl_cert(self, url: str) -> bool:
        """Delegate to the module-level ``verify_ssl_cert``."""
        return verify_ssl_cert(url)

    def _rate_limit_delay(self) -> float:
        """Return the seconds to wait before the next request (0.0 if none)."""
        if not (self.requests_per_second and self.last_request_time):
            return 0.0
        min_interval = timedelta(seconds=1.0 / self.requests_per_second)
        time_since_last = datetime.now() - self.last_request_time
        return max((min_interval - time_since_last).total_seconds(), 0.0)

    async def _wait_for_rate_limit(self):
        """Wait to respect the rate limit if specified."""
        delay = self._rate_limit_delay()
        if delay > 0:
            await asyncio.sleep(delay)
        self.last_request_time = datetime.now()

    def _sync_wait_for_rate_limit(self):
        """Synchronous version of rate limit wait."""
        # Imported locally: the module-level name ``time`` is datetime.time,
        # which has no ``sleep``.
        from time import sleep

        delay = self._rate_limit_delay()
        if delay > 0:
            sleep(delay)
        self.last_request_time = datetime.now()

    async def _safe_process_url(self, url: str) -> bool:
        """Perform safety checks before processing a URL.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        await self._wait_for_rate_limit()
        return True

    def _safe_process_url_sync(self, url: str) -> bool:
        """Synchronous version of safety checks.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        self._sync_wait_for_rate_limit()
        return True
class SafePlaywrightURLLoader(PlaywrightURLLoader):
    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.

    Attributes:
        web_paths (List[str]): List of URLs to load.
        verify_ssl (bool): If True, verify SSL certificates.
        trust_env (bool): If True, use proxy settings from environment variables.
        requests_per_second (Optional[float]): Number of requests per second to limit to.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        headless (bool): If True, the browser will run in headless mode.
        proxy (dict): Proxy override settings for the Playwright session.
        playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
    """

    def __init__(
        self,
        web_paths: List[str],
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        headless: bool = True,
        remove_selectors: Optional[List[str]] = None,
        proxy: Optional[Dict[str, str]] = None,
        playwright_ws_url: Optional[str] = None,
    ):
        """Initialize with additional safety parameters and remote browser support."""
        if trust_env and not (proxy or {}).get("server"):
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                # Build a new dict so the caller's mapping is never mutated.
                proxy = {**(proxy or {}), "server": env_proxy_server}

        # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
        super().__init__(
            urls=web_paths,
            continue_on_failure=continue_on_failure,
            headless=headless if playwright_ws_url is None else False,
            remove_selectors=remove_selectors,
            proxy=proxy,
        )
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.playwright_ws_url = playwright_ws_url
        self.trust_env = trust_env

    def lazy_load(self) -> Iterator[Document]:
        """Safely load URLs synchronously with support for remote browser."""
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = p.chromium.connect(self.playwright_ws_url)
            else:
                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)

            # try/finally: close the browser even when an error propagates or
            # the consumer stops iterating early.
            try:
                for url in self.urls:
                    try:
                        self._safe_process_url_sync(url)
                        page = browser.new_page()
                        response = page.goto(url)
                        if response is None:
                            raise ValueError(f"page.goto() returned None for url {url}")

                        text = self.evaluator.evaluate(page, browser, response)
                        metadata = {"source": url}
                        yield Document(page_content=text, metadata=metadata)
                    except Exception:
                        if self.continue_on_failure:
                            # The message must be the first argument; the
                            # traceback is attached by log.exception.
                            log.exception("Error loading %s", url)
                            continue
                        raise
            finally:
                browser.close()

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Safely load URLs asynchronously with support for remote browser."""
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = await p.chromium.connect(self.playwright_ws_url)
            else:
                browser = await p.chromium.launch(
                    headless=self.headless, proxy=self.proxy
                )

            # Same cleanup guarantee as the synchronous path.
            try:
                for url in self.urls:
                    try:
                        await self._safe_process_url(url)
                        page = await browser.new_page()
                        response = await page.goto(url)
                        if response is None:
                            raise ValueError(f"page.goto() returned None for url {url}")

                        text = await self.evaluator.evaluate_async(
                            page, browser, response
                        )
                        metadata = {"source": url}
                        yield Document(page_content=text, metadata=metadata)
                    except Exception:
                        if self.continue_on_failure:
                            log.exception("Error loading %s", url)
                            continue
                        raise
            finally:
                await browser.close()

    def _verify_ssl_cert(self, url: str) -> bool:
        """Delegate to the module-level ``verify_ssl_cert``."""
        return verify_ssl_cert(url)

    def _rate_limit_delay(self) -> float:
        """Return the seconds to wait before the next request (0.0 if none)."""
        if not (self.requests_per_second and self.last_request_time):
            return 0.0
        min_interval = timedelta(seconds=1.0 / self.requests_per_second)
        time_since_last = datetime.now() - self.last_request_time
        return max((min_interval - time_since_last).total_seconds(), 0.0)

    async def _wait_for_rate_limit(self):
        """Wait to respect the rate limit if specified."""
        delay = self._rate_limit_delay()
        if delay > 0:
            await asyncio.sleep(delay)
        self.last_request_time = datetime.now()

    def _sync_wait_for_rate_limit(self):
        """Synchronous version of rate limit wait."""
        # Imported locally: the module-level name ``time`` is datetime.time,
        # which has no ``sleep``.
        from time import sleep

        delay = self._rate_limit_delay()
        if delay > 0:
            sleep(delay)
        self.last_request_time = datetime.now()

    async def _safe_process_url(self, url: str) -> bool:
        """Perform safety checks before processing a URL.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        await self._wait_for_rate_limit()
        return True

    def _safe_process_url_sync(self, url: str) -> bool:
        """Synchronous version of safety checks.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        self._sync_wait_for_rate_limit()
        return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
@@ -143,20 +467,12 @@ class SafeWebBaseLoader(WebBaseLoader):
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = extract_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
log.exception(e, "Error loading %s", path)
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async lazy load text from the url(s) in web_path."""
@@ -179,6 +495,12 @@ class SafeWebBaseLoader(WebBaseLoader):
return [document async for document in self.alazy_load()]
# Registry mapping a web-loader engine name to the loader class to instantiate.
# It is a defaultdict, so any name that is not registered below resolves to
# SafeWebBaseLoader instead of raising KeyError.
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
@@ -188,10 +510,29 @@ def get_web_loader(
# Check if the URLs are valid
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
return SafeWebBaseLoader(
web_path=safe_urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
trust_env=trust_env,
web_loader_args = {
"web_paths": safe_urls,
"verify_ssl": verify_ssl,
"requests_per_second": requests_per_second,
"continue_on_failure": True,
"trust_env": trust_env,
}
if PLAYWRIGHT_WS_URI.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
# Create the appropriate WebLoader based on the configuration
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader

View File

@@ -37,6 +37,7 @@ from open_webui.config import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
@@ -70,7 +71,7 @@ from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
log.error(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
@@ -87,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
log.info(f"Converted {file_path} to {output_path}")
def set_faster_whisper_model(model: str, auto_update: bool = False):
@@ -265,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
payload["model"] = request.app.state.config.TTS_MODEL
try:
# print(payload)
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
@@ -323,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
)
try:
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
@@ -380,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
headers={
@@ -458,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
def transcribe(request: Request, file_path):
print("transcribe", file_path)
log.info(f"transcribe: {file_path}")
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
@@ -670,7 +679,22 @@ def transcription(
def get_available_models(request: Request) -> list[dict]:
available_models = []
if request.app.state.config.TTS_ENGINE == "openai":
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
)
response.raise_for_status()
data = response.json()
available_models = data.get("models", [])
except Exception as e:
log.error(f"Error fetching models from custom endpoint: {str(e)}")
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
else:
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
@@ -701,14 +725,37 @@ def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict"""
available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai":
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
)
response.raise_for_status()
data = response.json()
voices_list = data.get("voices", [])
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
except Exception as e:
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
else:
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
available_voices = get_elevenlabs_voices(

View File

@@ -31,10 +31,7 @@ from open_webui.env import (
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
)
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
@@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
if ENABLE_LDAP.value:
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
router = APIRouter()
@@ -252,14 +251,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not user:
try:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
role = (
"admin"
@@ -423,7 +414,6 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
user_count = Users.get_num_users()
if WEBUI_AUTH:
if (
@@ -434,16 +424,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
if user_count != 0:
if Users.get_num_users() != 0:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
user_count = Users.get_num_users()
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -546,7 +532,8 @@ async def signout(request: Request, response: Response):
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}"
headers=response.headers,
url=f"{logout_url}?id_token_hint={oauth_id_token}",
)
else:
raise HTTPException(
@@ -612,7 +599,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email:
admin = Users.get_user_by_email(admin_email)

View File

@@ -70,6 +70,12 @@ async def set_direct_connections_config(
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
CODE_EXECUTION_ENGINE: str
CODE_EXECUTION_JUPYTER_URL: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
@@ -77,11 +83,18 @@ class CodeInterpreterConfigForm(BaseModel):
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -89,13 +102,32 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
async def set_code_execution_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
form_data.CODE_EXECUTION_JUPYTER_URL
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
form_data.CODE_EXECUTION_JUPYTER_AUTH
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
)
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
@@ -116,8 +148,17 @@ async def set_code_interpreter_config(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
)
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -125,6 +166,7 @@ async def set_code_interpreter_config(
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}

View File

@@ -16,6 +16,7 @@ from open_webui.models.files import (
Files,
)
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
@@ -67,7 +68,22 @@ def upload_file(
)
try:
process_file(request, ProcessFileForm(file_id=id), user=user)
if file.content_type in [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/x-m4a",
]:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
else:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@@ -225,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
content_type = file.meta.get("content_type")
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename)
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
if content_type == "application/pdf" or filename.lower().endswith(
".pdf"
):
headers["Content-Disposition"] = (
f"inline; filename*=UTF-8''{encoded_filename}"
)
content_type = "application/pdf"
elif content_type != "text/plain":
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
return FileResponse(file_path, headers=headers, media_type=content_type)
else:
raise HTTPException(
@@ -266,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
log.info(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(

View File

@@ -1,4 +1,5 @@
import os
import logging
from pathlib import Path
from typing import Optional
@@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -79,7 +85,7 @@ async def create_new_function(
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
log.exception(f"Failed to create a new function: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -183,7 +189,7 @@ async def update_function_by_id(
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
@@ -299,7 +305,7 @@ async def update_function_valves_by_id(
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

16
backend/open_webui/routers/groups.py Normal file → Executable file
View File

@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Optional
import logging
from open_webui.models.users import Users
from open_webui.models.groups import (
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
)
except Exception as e:
print(e)
log.exception(f"Error creating a new group: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -94,7 +100,7 @@ async def update_group_by_id(
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
)
except Exception as e:
print(e)
log.exception(f"Error updating group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
)
except Exception as e:
print(e)
log.exception(f"Error deleting group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

View File

@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
COMFYUI_WORKFLOW_NODES: list[dict]
# Settings for the "gemini" image-generation engine; ConfigForm carries one
# of these in its `gemini` field.
class GeminiConfigForm(BaseModel):
# Base URL of the Gemini API. Image generation posts to
# "<base url>/models/<model>:predict".
GEMINI_API_BASE_URL: str
# API key; sent to the service in the "x-goog-api-key" request header.
GEMINI_API_KEY: str
class ConfigForm(BaseModel):
enabled: bool
engine: str
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
openai: OpenAIConfigForm
automatic1111: Automatic1111ConfigForm
comfyui: ComfyUIConfigForm
gemini: GeminiConfigForm
@router.post("/config/update")
@@ -103,6 +113,11 @@ async def update_config(
)
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
form_data.gemini.GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
)
@@ -129,6 +144,8 @@ async def update_config(
request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
)
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
@@ -155,6 +172,10 @@ async def update_config(
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
headers = None
if request.app.state.config.COMFYUI_API_KEY:
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
try:
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
)
r.raise_for_status()
return True
@@ -224,6 +253,12 @@ def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "imagen-3.0-generate-002"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
@@ -299,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
headers = {
@@ -322,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
if model_node_id:
model_list_key = None
print(workflow[model_node_id]["class_type"])
log.info(workflow[model_node_id]["class_type"])
for key in info[workflow[model_node_id]["class_type"]]["input"][
"required"
]:
@@ -483,6 +522,41 @@ async def image_generations(
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
headers = {}
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
}
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
json=data,
headers=headers,
)
r.raise_for_status()
res = r.json()
images = []
for image in res["predictions"]:
image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, data, image_data, content_type, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,

View File

@@ -614,7 +614,7 @@ def add_files_to_knowledge_batch(
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)

View File

@@ -14,6 +14,11 @@ from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from open_webui.models.users import UserModel
from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import (
Depends,
@@ -26,7 +31,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask
@@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################
async def send_get_request(url, key=None):
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@@ -96,6 +115,7 @@ async def send_post_request(
stream: bool = True,
key: Optional[str] = None,
content_type: Optional[str] = None,
user: UserModel = None,
):
r = None
@@ -110,6 +130,16 @@ async def send_post_request(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@@ -191,7 +221,19 @@ async def verify_connection(
try:
async with session.get(
f"{url}/api/version",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as r:
if r.status != 200:
detail = f"HTTP Error: {r.status}"
@@ -254,7 +296,7 @@ async def update_config(
@cached(ttl=3)
async def get_all_models(request: Request):
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
@@ -262,7 +304,7 @@ async def get_all_models(request: Request):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/tags"))
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
@@ -275,7 +317,9 @@ async def get_all_models(request: Request):
key = api_config.get("key", None)
if enable:
request_tasks.append(send_get_request(f"{url}/api/tags", key))
request_tasks.append(
send_get_request(f"{url}/api/tags", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
@@ -360,7 +404,7 @@ async def get_ollama_tags(
models = []
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -370,7 +414,19 @@ async def get_ollama_tags(
r = requests.request(
method="GET",
url=f"{url}/api/tags",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
@@ -509,6 +566,7 @@ async def pull_model(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -527,7 +585,7 @@ async def push_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@@ -545,6 +603,7 @@ async def push_model(
url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -571,6 +630,7 @@ async def create_model(
url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -588,7 +648,7 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.source in models:
@@ -609,6 +669,16 @@ async def copy_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -643,7 +713,7 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@@ -665,6 +735,16 @@ async def delete_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@@ -693,7 +773,7 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name not in models:
@@ -714,6 +794,16 @@ async def show_model_info(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -757,7 +847,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -783,6 +873,16 @@ async def embed(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -826,7 +926,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -852,6 +952,16 @@ async def embeddings(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -901,7 +1011,7 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -931,15 +1041,29 @@ async def generate_completion(
url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
class ChatMessage(BaseModel):
    """A single message of a chat-completion request.

    A message must carry ``content``, ``tool_calls``, or both. ``images`` is
    optional and does not count towards that requirement.

    Raises:
        ValueError: (surfaced as a pydantic validation error) when both
            ``content`` and ``tool_calls`` are missing or ``None``.
    """

    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None

    # The check hangs off ``tool_calls`` (not ``content``) on purpose: field
    # validators run in declaration order and ``values`` only contains the
    # fields validated so far, so a validator on ``content`` can never see
    # ``tool_calls`` and would reject ``content=None`` even when tool calls
    # are present. ``always=True`` makes the check run when ``tool_calls`` is
    # omitted, so a message with neither field is rejected as well.
    @validator("tool_calls", pre=True, always=True)
    @classmethod
    def check_at_least_one_field(cls, field_value, values, **kwargs):
        """Reject a message that has neither ``content`` nor ``tool_calls``."""
        if field_value is None and values.get("content") is None:
            raise ValueError(
                "At least one of 'content' or 'tool_calls' must be provided"
            )

        return field_value
class GenerateChatCompletionForm(BaseModel):
model: str
@@ -1047,6 +1171,7 @@ async def generate_chat_completion(
stream=form_data.stream,
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson",
user=user,
)
@@ -1149,6 +1274,7 @@ async def generate_openai_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -1227,6 +1353,7 @@ async def generate_openai_chat_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -1240,7 +1367,7 @@ async def get_openai_models(
models = []
if url_idx is None:
model_list = await get_all_models(request)
model_list = await get_all_models(request, user=user)
models = [
{
"id": model["model"],

View File

@@ -26,6 +26,7 @@ from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
@@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
##########################################
async def send_get_request(url, key=None):
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@@ -84,9 +98,15 @@ def openai_o1_o3_handler(payload):
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
# Fix: o1 and o3 do not support the "system" role directly.
# For older models like "o1-mini" or "o1-preview", use role "user".
# For newer o1/o3 models, replace "system" with "developer".
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
model_lower = payload["model"].lower()
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
payload["messages"][0]["role"] = "user"
else:
payload["messages"][0]["role"] = "developer"
return payload
@@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request) -> list:
async def get_all_models_responses(request: Request, user: UserModel) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
@@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
):
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@@ -352,13 +375,13 @@ async def get_filtered_models(models, user):
@cached(ttl=3)
async def get_all_models(request: Request) -> dict[str, list]:
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user=user)
def extract_data(response):
if response and "data" in response:
@@ -418,7 +441,7 @@ async def get_models(
}
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
@@ -515,6 +538,16 @@ async def verify_connection(
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
@@ -587,7 +620,7 @@ async def generate_chat_completion(
detail="Model not found",
)
await get_all_models(request)
await get_all_models(request, user=user)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
@@ -777,7 +810,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None:
try:
res = await r.json()
print(res)
log.error(res)
if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:

View File

@@ -101,7 +101,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
if "detail" in res:
raise Exception(response.status, res["detail"])
except Exception as e:
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
return payload
@@ -153,7 +153,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
except Exception:
pass
except Exception as e:
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
return payload
@@ -169,7 +169,7 @@ router = APIRouter()
@router.get("/list")
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user)
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
urlIdxs = [
@@ -196,7 +196,7 @@ async def upload_pipeline(
file: UploadFile = File(...),
user=Depends(get_admin_user),
):
print("upload_pipeline", urlIdx, file.filename)
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
# Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")):
raise HTTPException(
@@ -231,7 +231,7 @@ async def upload_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
status_code = status.HTTP_404_NOT_FOUND
@@ -282,7 +282,7 @@ async def add_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -327,7 +327,7 @@ async def delete_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -361,7 +361,7 @@ async def get_pipelines(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -400,7 +400,7 @@ async def get_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -440,7 +440,7 @@ async def get_pipeline_valves_spec(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -482,7 +482,7 @@ async def update_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None

View File

@@ -351,10 +351,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
},
},
"chunk": {
"text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -371,10 +378,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
@@ -397,6 +406,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
@@ -409,9 +419,15 @@ class FileConfig(BaseModel):
max_count: Optional[int] = None
# Request payload carrying the Document Intelligence connection settings.
# Both fields are required; update_rag_config copies them into
# DOCUMENT_INTELLIGENCE_ENDPOINT and DOCUMENT_INTELLIGENCE_KEY.
class DocumentIntelligenceConfigForm(BaseModel):
    # Stored as config.DOCUMENT_INTELLIGENCE_ENDPOINT.
    endpoint: str
    # Stored as config.DOCUMENT_INTELLIGENCE_KEY.
    key: str
# Request payload describing how document content is extracted.
# update_rag_config copies `engine` and `tika_server_url` unconditionally and
# only touches the Document Intelligence settings when that sub-form is given.
class ContentExtractionConfig(BaseModel):
    # Stored as config.CONTENT_EXTRACTION_ENGINE; defaults to an empty string.
    engine: str = ""
    # Stored as config.TIKA_SERVER_URL; optional.
    tika_server_url: Optional[str] = None
    # Optional nested endpoint/key pair for Document Intelligence.
    document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
class ChunkParamUpdateForm(BaseModel):
@@ -457,12 +473,16 @@ class WebSearchConfig(BaseModel):
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
RAG_FULL_CONTEXT: Optional[bool] = None
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
pdf_extract_images: Optional[bool] = None
enable_google_drive_integration: Optional[bool] = None
enable_onedrive_integration: Optional[bool] = None
file: Optional[FileConfig] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
@@ -480,24 +500,51 @@ async def update_rag_config(
else request.app.state.config.PDF_EXTRACT_IMAGES
)
request.app.state.config.RAG_FULL_CONTEXT = (
form_data.RAG_FULL_CONTEXT
if form_data.RAG_FULL_CONTEXT is not None
else request.app.state.config.RAG_FULL_CONTEXT
)
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
form_data.enable_google_drive_integration
if form_data.enable_google_drive_integration is not None
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
)
request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
form_data.enable_onedrive_integration
if form_data.enable_onedrive_integration is not None
else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
)
if form_data.file is not None:
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
log.info(
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
)
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
form_data.content_extraction.engine
)
request.app.state.config.TIKA_SERVER_URL = (
form_data.content_extraction.tika_server_url
)
if form_data.content_extraction.document_intelligence_config is not None:
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.content_extraction.document_intelligence_config.endpoint
)
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
form_data.content_extraction.document_intelligence_config.key
)
if form_data.chunk is not None:
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
@@ -512,11 +559,16 @@ async def update_rag_config(
if form_data.web is not None:
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
form_data.web.web_loader_ssl_verification
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.SEARXNG_QUERY_URL = (
form_data.web.search.searxng_query_url
)
@@ -581,6 +633,8 @@ async def update_rag_config(
return {
"status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
"file": {
"max_size": request.app.state.config.FILE_MAX_SIZE,
"max_count": request.app.state.config.FILE_MAX_COUNT,
@@ -588,6 +642,10 @@ async def update_rag_config(
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
},
},
"chunk": {
"text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -600,7 +658,8 @@ async def update_rag_config(
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -863,7 +922,12 @@ def process_file(
# Update the content in the file
# Usage: /files/{file_id}/data/content/update
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
try:
# /files/{file_id}/data/content/update
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
except:
# Audio file upload pipeline
pass
docs = [
Document(
@@ -920,6 +984,8 @@ def process_file(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
)
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
@@ -962,36 +1028,45 @@ def process_file(
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
try:
result = save_docs_to_vector_db(
request,
docs=docs,
collection_name=collection_name,
metadata={
"file_id": file.id,
"name": file.filename,
"hash": hash,
},
add=(True if form_data.collection_name else False),
user=user,
)
if result:
Files.update_file_metadata_by_id(
file.id,
{
"collection_name": collection_name,
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
try:
result = save_docs_to_vector_db(
request,
docs=docs,
collection_name=collection_name,
metadata={
"file_id": file.id,
"name": file.filename,
"hash": hash,
},
add=(True if form_data.collection_name else False),
user=user,
)
return {
"status": True,
"collection_name": collection_name,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
raise e
if result:
Files.update_file_metadata_by_id(
file.id,
{
"collection_name": collection_name,
},
)
return {
"status": True,
"collection_name": collection_name,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
raise e
else:
return {
"status": True,
"collection_name": None,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
@@ -1262,6 +1337,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.TAVILY_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
@@ -1349,21 +1425,37 @@ async def process_web_search(
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user,
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
return {
"status": True,
"collection_name": None,
"filenames": urls,
"docs": [
{
"content": doc.page_content,
"metadata": doc.metadata,
}
for doc in docs
],
"loaded_count": len(docs),
}
else:
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user,
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
except Exception as e:
log.exception(e)
raise HTTPException(
@@ -1520,11 +1612,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
log.exception(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
log.warning(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
return True

View File

@@ -20,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
@@ -208,7 +212,7 @@ async def generate_title(
"stream": False,
**(
{"max_tokens": 1000}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 1000,
}
@@ -221,6 +225,12 @@ async def generate_title(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -290,6 +300,12 @@ async def generate_chat_tags(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -356,6 +372,12 @@ async def generate_image_prompt(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -433,6 +455,12 @@ async def generate_queries(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -514,6 +542,12 @@ async def generate_autocompletion(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -571,7 +605,7 @@ async def generate_emoji(
"stream": False,
**(
{"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 4,
}
@@ -584,6 +618,12 @@ async def generate_emoji(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -644,6 +684,12 @@ async def generate_moa_response(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:

View File

@@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Optional
@@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -111,7 +116,7 @@ async def create_new_tools(
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
)
except Exception as e:
print(e)
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -193,7 +198,7 @@ async def update_tools_by_id(
"specs": specs,
}
print(updated)
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
if tools:
@@ -343,7 +348,7 @@ async def update_tools_valves_by_id(
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -421,7 +426,7 @@ async def update_tools_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),

View File

@@ -1,48 +1,84 @@
import black
import logging
import markdown
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.auth import get_admin_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@router.get("/gravatar")
async def get_gravatar(
email: str,
):
async def get_gravatar(email: str, user=Depends(get_verified_user)):
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
class CodeForm(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
return {"code": form_data.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# Run the submitted code on the configured execution engine and return the
# engine's output. Only the "jupyter" engine is handled; any other configured
# engine is rejected with HTTP 400. Requires a verified user.
@router.post("/code/execute")
async def execute_code(
    request: Request, form_data: CodeForm, user=Depends(get_verified_user)
):
    config = request.app.state.config

    # Guard clause: bail out early for every engine except Jupyter.
    if config.CODE_EXECUTION_ENGINE != "jupyter":
        raise HTTPException(
            status_code=400,
            detail="Code execution engine not supported",
        )

    jupyter_url = config.CODE_EXECUTION_JUPYTER_URL

    # Exactly one credential (or none) is forwarded, chosen by the auth mode.
    auth_mode = config.CODE_EXECUTION_JUPYTER_AUTH
    if auth_mode == "token":
        auth_token = config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
    else:
        auth_token = None
    if auth_mode == "password":
        auth_password = config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
    else:
        auth_password = None

    return await execute_code_jupyter(
        jupyter_url,
        form_data.code,
        auth_token,
        auth_password,
        config.CODE_EXECUTION_JUPYTER_TIMEOUT,
    )
# Request payload for the markdown-to-HTML endpoint.
class MarkdownForm(BaseModel):
    # Markdown source text; passed to markdown.markdown() by the route handler.
    md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
form_data: MarkdownForm, user=Depends(get_verified_user)
):
return {"html": markdown.markdown(form_data.md)}
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatTitleMessagesForm,
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
):
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
log.exception(f"Error generating PDF: {e}")
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -1,10 +1,12 @@
import os
import shutil
import json
import logging
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from open_webui.config import (
S3_ACCESS_KEY_ID,
@@ -13,14 +15,27 @@ from open_webui.config import (
S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
S3_USE_ACCELERATE_ENDPOINT,
S3_ADDRESSING_STYLE,
GCS_BUCKET_NAME,
GOOGLE_APPLICATION_CREDENTIALS_JSON,
AZURE_STORAGE_ENDPOINT,
AZURE_STORAGE_CONTAINER_NAME,
AZURE_STORAGE_KEY,
STORAGE_PROVIDER,
UPLOAD_DIR,
)
from google.cloud import storage
from google.cloud.exceptions import GoogleCloudError, NotFound
from open_webui.constants import ERROR_MESSAGES
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceNotFoundError
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class StorageProvider(ABC):
@@ -65,7 +80,7 @@ class LocalStorageProvider(StorageProvider):
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
log.warning(f"File {file_path} not found in local storage.")
@staticmethod
def delete_all_files() -> None:
@@ -79,9 +94,9 @@ class LocalStorageProvider(StorageProvider):
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
log.exception(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")
class S3StorageProvider(StorageProvider):
@@ -92,6 +107,12 @@ class S3StorageProvider(StorageProvider):
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=Config(
s3={
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
"addressing_style": S3_ADDRESSING_STYLE,
},
),
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
@@ -221,6 +242,74 @@ class GCSStorageProvider(StorageProvider):
LocalStorageProvider.delete_all_files()
class AzureStorageProvider(StorageProvider):
    """Storage provider backed by an Azure Blob Storage container.

    Uploads are written to local storage first (via ``LocalStorageProvider``)
    and then pushed to the configured container; downloads are written into
    ``UPLOAD_DIR``. Blobs are addressed by the last path segment of the file
    path they are given.
    """

    def __init__(self):
        """Create the blob service and container clients from module config."""
        self.endpoint = AZURE_STORAGE_ENDPOINT
        self.container_name = AZURE_STORAGE_CONTAINER_NAME

        # Use the account key when one is configured; otherwise fall back to
        # DefaultAzureCredential to support Managed Identity authentication.
        credential = AZURE_STORAGE_KEY or DefaultAzureCredential()
        self.blob_service_client = BlobServiceClient(
            account_url=self.endpoint, credential=credential
        )
        self.container_client = self.blob_service_client.get_container_client(
            self.container_name
        )

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to Azure Blob Storage.

        Returns:
            The file contents and the blob location formatted as
            ``{endpoint}/{container_name}/{filename}``.

        Raises:
            RuntimeError: if the blob upload fails for any reason.
        """
        # The local copy is written first; only its contents are needed here.
        contents, _ = LocalStorageProvider.upload_file(file, filename)
        try:
            blob_client = self.container_client.get_blob_client(filename)
            blob_client.upload_blob(contents, overwrite=True)
        except Exception as e:
            raise RuntimeError(
                f"Error uploading file to Azure Blob Storage: {e}"
            ) from e
        return contents, f"{self.endpoint}/{self.container_name}/{filename}"

    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from Azure Blob Storage.

        Returns:
            The path of the downloaded copy inside ``UPLOAD_DIR``.

        Raises:
            RuntimeError: if the blob does not exist.
        """
        filename = file_path.split("/")[-1]
        local_file_path = f"{UPLOAD_DIR}/{filename}"
        try:
            blob_client = self.container_client.get_blob_client(filename)
            # Download before opening the local file: opening with "wb" first
            # would truncate an existing local copy (or leave an empty file
            # behind) when the blob turns out to be missing.
            data = blob_client.download_blob().readall()
        except ResourceNotFoundError as e:
            raise RuntimeError(
                f"Error downloading file from Azure Blob Storage: {e}"
            ) from e

        with open(local_file_path, "wb") as download_file:
            download_file.write(data)
        return local_file_path

    def delete_file(self, file_path: str) -> None:
        """Handles deletion of the file from Azure Blob Storage.

        Raises:
            RuntimeError: if the blob does not exist. In that case the local
                copy is left untouched.
        """
        filename = file_path.split("/")[-1]
        try:
            blob_client = self.container_client.get_blob_client(filename)
            blob_client.delete_blob()
        except ResourceNotFoundError as e:
            raise RuntimeError(
                f"Error deleting file from Azure Blob Storage: {e}"
            ) from e

        # Remove the local copy once the blob has been deleted.
        LocalStorageProvider.delete_file(file_path)

    def delete_all_files(self) -> None:
        """Handles deletion of all files from Azure Blob Storage.

        Raises:
            RuntimeError: if listing or deleting blobs fails. In that case the
                local files are left untouched.
        """
        try:
            for blob in self.container_client.list_blobs():
                self.container_client.delete_blob(blob.name)
        except Exception as e:
            raise RuntimeError(
                f"Error deleting all files from Azure Blob Storage: {e}"
            ) from e

        # Remove the local copies once the container has been emptied.
        LocalStorageProvider.delete_all_files()
def get_storage_provider(storage_provider: str):
if storage_provider == "local":
Storage = LocalStorageProvider()
@@ -228,6 +317,8 @@ def get_storage_provider(storage_provider: str):
Storage = S3StorageProvider()
elif storage_provider == "gcs":
Storage = GCSStorageProvider()
elif storage_provider == "azure":
Storage = AzureStorageProvider()
else:
raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
return Storage

View File

@@ -7,6 +7,8 @@ from moto import mock_aws
from open_webui.storage import provider
from gcp_storage_emulator.server import create_server
from google.cloud import storage
from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
from unittest.mock import MagicMock
def mock_upload_dir(monkeypatch, tmp_path):
@@ -22,6 +24,7 @@ def test_imports():
provider.LocalStorageProvider
provider.S3StorageProvider
provider.GCSStorageProvider
provider.AzureStorageProvider
provider.Storage
@@ -32,6 +35,8 @@ def test_get_storage_provider():
assert isinstance(Storage, provider.S3StorageProvider)
Storage = provider.get_storage_provider("gcs")
assert isinstance(Storage, provider.GCSStorageProvider)
Storage = provider.get_storage_provider("azure")
assert isinstance(Storage, provider.AzureStorageProvider)
with pytest.raises(RuntimeError):
provider.get_storage_provider("invalid")
@@ -48,6 +53,7 @@ def test_class_instantiation():
provider.LocalStorageProvider()
provider.S3StorageProvider()
provider.GCSStorageProvider()
provider.AzureStorageProvider()
class TestLocalStorageProvider:
@@ -272,3 +278,147 @@ class TestGCSStorageProvider:
assert not (upload_dir / self.filename_extra).exists()
assert self.Storage.bucket.get_blob(self.filename) == None
assert self.Storage.bucket.get_blob(self.filename_extra) == None
class TestAzureStorageProvider:
    """Exercise ``provider.AzureStorageProvider`` against mocked Azure clients.

    No real Azure endpoint is contacted: the provider's service and container
    clients are replaced with ``MagicMock`` objects, so the tests check the
    calls the provider makes and the files it leaves in the local upload dir.

    The class deliberately has no ``__init__``: pytest refuses to collect test
    classes that define one.
    """

    @pytest.fixture(autouse=True)
    def setup_storage(self, monkeypatch):
        """Create a provider wired to fresh mocks before every test.

        The fixture is function-scoped because ``monkeypatch`` is
        function-scoped (a class-scoped fixture cannot request it), and
        ``autouse`` so each test method gets ``self.Storage`` without having
        to name the fixture.
        """
        # Create mock Blob Service Client and related clients
        mock_blob_service_client = MagicMock()
        mock_container_client = MagicMock()
        mock_blob_client = MagicMock()

        # Set up return values for the mock
        mock_blob_service_client.get_container_client.return_value = (
            mock_container_client
        )
        mock_container_client.get_blob_client.return_value = mock_blob_client

        # Monkeypatch the Azure classes to return our mocks. Dotted string
        # targets are used because this module only imports names *from*
        # ``azure.storage.blob``; the bare name ``azure`` is not bound here.
        monkeypatch.setattr(
            "azure.storage.blob.BlobServiceClient",
            lambda *args, **kwargs: mock_blob_service_client,
        )
        monkeypatch.setattr(
            "azure.storage.blob.ContainerClient",
            lambda *args, **kwargs: mock_container_client,
        )
        monkeypatch.setattr(
            "azure.storage.blob.BlobClient",
            lambda *args, **kwargs: mock_blob_client,
        )

        self.Storage = provider.AzureStorageProvider()
        self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
        self.Storage.container_name = "my-container"
        self.file_content = b"test content"
        self.filename = "test.txt"
        self.filename_extra = "test_extra.txt"
        self.file_bytesio_empty = io.BytesIO()

        # Apply mocks to the Storage instance
        self.Storage.blob_service_client = mock_blob_service_client
        self.Storage.container_client = mock_container_client

    def test_upload_file(self, monkeypatch, tmp_path):
        """Upload fails when the blob client errors, then succeeds and mirrors locally."""
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)

        # Simulate an error when container does not exist
        self.Storage.container_client.get_blob_client.side_effect = Exception(
            "Container does not exist"
        )
        with pytest.raises(Exception):
            self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)

        # Reset side effect and create container
        self.Storage.container_client.get_blob_client.side_effect = None
        self.Storage.create_container()
        contents, azure_file_path = self.Storage.upload_file(
            io.BytesIO(self.file_content), self.filename
        )

        # Assertions
        self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
        self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
            self.file_content, overwrite=True
        )
        assert contents == self.file_content
        assert (
            azure_file_path
            == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        )
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

        with pytest.raises(ValueError):
            self.Storage.upload_file(self.file_bytesio_empty, self.filename)

    def test_get_file(self, monkeypatch, tmp_path):
        """``get_file`` returns the local path and writes the downloaded bytes there."""
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock upload behavior
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        # Mock blob download behavior
        self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
            self.file_content
        )

        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        file_path = self.Storage.get_file(file_url)

        assert file_path == str(upload_dir / self.filename)
        assert (upload_dir / self.filename).exists()
        assert (upload_dir / self.filename).read_bytes() == self.file_content

    def test_delete_file(self, monkeypatch, tmp_path):
        """``delete_file`` deletes the blob once and removes the local copy."""
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock file upload
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        # Mock deletion
        self.Storage.container_client.get_blob_client().delete_blob.return_value = None

        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
        self.Storage.delete_file(file_url)

        self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
        assert not (upload_dir / self.filename).exists()

    def test_delete_all_files(self, monkeypatch, tmp_path):
        """``delete_all_files`` lists the container, deletes blobs and clears local files."""
        upload_dir = mock_upload_dir(monkeypatch, tmp_path)
        self.Storage.create_container()

        # Mock file uploads
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
        self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)

        # Mock listing and deletion behavior
        self.Storage.container_client.list_blobs.return_value = [
            {"name": self.filename},
            {"name": self.filename_extra},
        ]
        self.Storage.container_client.get_blob_client().delete_blob.return_value = None

        self.Storage.delete_all_files()

        self.Storage.container_client.list_blobs.assert_called_once()
        self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
        assert not (upload_dir / self.filename).exists()
        assert not (upload_dir / self.filename_extra).exists()

    def test_get_file_not_found(self, monkeypatch):
        """A download error from the blob client propagates out of ``get_file``."""
        self.Storage.create_container()
        file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"

        # Mock behavior to raise an error for missing blobs
        self.Storage.container_client.get_blob_client().download_blob.side_effect = (
            Exception("Blob not found")
        )
        with pytest.raises(Exception, match="Blob not found"):
            self.Storage.get_file(file_url)

View File

@@ -0,0 +1,249 @@
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
import re
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
MutableMapping,
Optional,
cast,
)
import uuid
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
ASGISendEvent,
Scope as ASGIScope,
)
from loguru import logger
from starlette.requests import Request
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
from open_webui.models.users import UserModel
if TYPE_CHECKING:
from loguru import Logger
@dataclass(frozen=True)
class AuditLogEntry:
    """Immutable record describing one audited HTTP request.

    Fields are grouped by the audit level they are labelled with below; every
    field after ``request_uri`` is optional and defaults to ``None``.
    """

    # --- `Metadata` audit level properties ---
    id: str
    user: dict[str, Any]
    audit_level: str
    verb: str
    request_uri: str
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None

    # --- `Request` audit level properties ---
    request_object: Any = None

    # --- `Request Response` level ---
    response_object: Any = None
    response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
    """Audit verbosity levels; each member's value is its own name as a string."""

    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
    """
    Thin wrapper that emits audit entries through a logger bound with
    ``auditable=True``, so sinks can filter audit records on that key.

    Parameters:
    logger (Logger): The logger to bind; it must provide ``bind`` and ``log``.
    """

    def __init__(self, logger: "Logger"):
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """Log ``audit_entry`` at ``log_level`` with an empty message.

        The entry's fields are passed as keyword arguments to the logger.
        A truthy ``extra`` is attached under the ``"extra"`` key; a falsy one
        (``None`` or an empty dict) is left out entirely.
        """
        payload = asdict(audit_entry)

        if extra:
            payload["extra"] = extra

        self.logger.log(log_level, "", **payload)
class AuditContext:
    """
    Per-request accumulator for the request and response bodies seen while a
    request is processed. Each body is capped at ``max_body_size`` bytes; any
    bytes beyond the cap are discarded.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture per body.
    metadata (Dict[str, Any]): Additional audit metadata collected along the way.
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.request_body = bytearray()
        self.response_body = bytearray()
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        # Only the part of the chunk that still fits under the cap is kept.
        room_left = self.max_body_size - len(buffer)
        if room_left > 0:
            buffer.extend(chunk[:room_left])

    def add_request_chunk(self, chunk: bytes):
        """Append ``chunk`` to the captured request body, up to the size cap."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Append ``chunk`` to the captured response body, up to the size cap."""
        self._append_capped(self.response_body, chunk)
class AuditLoggingMiddleware:
    """
    ASGI middleware that writes one structured audit entry per audited HTTP
    request.

    A request is audited when its method is in ``AUDITED_METHODS``, auditing is
    enabled, it carries an ``Authorization`` header, and its path is not
    excluded. Request bodies are captured at the ``REQUEST`` and
    ``REQUEST_RESPONSE`` levels; response status and body are captured only at
    ``REQUEST_RESPONSE``. The entry is written after the wrapped application
    returns or raises.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        """
        Parameters:
        app: The wrapped ASGI application.
        excluded_paths: Resource names (first path segment after ``/api/`` or
            ``/api/v1/``) that must not be audited. Entries are inserted into
            a regular expression as-is, so regex metacharacters are active.
        max_body_size: Maximum number of bytes captured per request/response body.
        audit_level: How much of each request is recorded.
        """
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level
        # Compiled once here instead of on every request.
        self._excluded_pattern = self._compile_excluded_pattern(self.excluded_paths)

    @staticmethod
    def _compile_excluded_pattern(
        excluded_paths: list[str],
    ) -> Optional["re.Pattern[str]"]:
        """Return the exclusion regex, or ``None`` when nothing is excluded.

        Empty entries are dropped: an empty alternative (or an empty list)
        would turn the group into ``()``, which matches every ``/api/...``
        path and would silently disable auditing altogether.
        """
        resources = [path for path in excluded_paths if path]
        if not resources:
            return None
        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
        return re.compile(r"^/api(?:/v1)?/(" + "|".join(resources) + r")\b")

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """ASGI entry point: pass non-audited traffic through, wrap the rest."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)

                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()

                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that yields a fresh ``AuditContext`` and records
        the audit entry on exit, whether or not the wrapped app raised.
        """
        # Honour the size limit this middleware was configured with.
        context = AuditContext(max_body_size=self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """Resolve the user from the request's ``Authorization`` header.

        Raises:
        ValueError: If the header is missing or empty.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            # Explicit check rather than ``assert``, which is removed under -O.
            raise ValueError("Missing Authorization header")

        user = get_current_user(request, None, get_http_authorization_cred(auth_header))
        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        """Return ``True`` when this request must not be audited."""
        if (
            request.method not in self.AUDITED_METHODS
            or AUDIT_LOG_LEVEL == "NONE"
            or self.audit_level == AuditLevel.NONE
            or not request.headers.get("authorization")
        ):
            return True

        if self._excluded_pattern is None:
            return False

        return self._excluded_pattern.match(request.url.path) is not None

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Append the body of an ``http.request`` message to the context."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the response status code and accumulate response body chunks."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """Build and write the audit entry for ``request``.

        This runs from the ``finally`` clause of ``_audit_context``, so any
        failure is logged and swallowed here instead of replacing the outcome
        of the request itself.
        """
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")

View File

@@ -5,6 +5,7 @@ import base64
import hmac
import hashlib
import requests
import os
from datetime import UTC, datetime, timedelta
@@ -13,15 +14,22 @@ from typing import Optional, Union, List, Dict
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import override_static
from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
from open_webui.env import (
WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY,
STATIC_DIR,
SRC_LOG_LEVELS,
)
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
logging.getLogger("passlib").setLevel(logging.ERROR)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"
@@ -47,6 +55,19 @@ def verify_signature(payload: str, signature: str) -> bool:
return False
def override_static(path: str, content: str):
    """Write Base64-encoded ``content`` to the file ``path`` inside ``STATIC_DIR``.

    Parameters:
    path: Bare file name, relative to ``STATIC_DIR``. It must not be empty and
        must not contain a path separator or ``..``.
    content: Base64-encoded file contents.

    Returns ``None``. An unsafe ``path`` is logged and ignored. Errors from
    Base64 decoding or from the filesystem are not caught here.
    """
    # Ensure path is safe. Backslashes are rejected as well as forward
    # slashes so a Windows-style separator cannot lead outside STATIC_DIR,
    # and an empty name is rejected because it would resolve to the
    # directory itself.
    if not path or "/" in path or "\\" in path or ".." in path:
        log.error(f"Invalid path: {path}")
        return

    file_path = os.path.join(STATIC_DIR, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(base64.b64decode(content))  # Convert Base64 back to raw binary
def get_license_data(app, key):
if key:
try:
@@ -69,11 +90,11 @@ def get_license_data(app, key):
return True
else:
print(
log.error(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
except Exception as ex:
print(f"License: Uncaught Exception: {ex}")
log.exception(f"License: Uncaught Exception: {ex}")
return False
@@ -129,6 +150,7 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user(
request: Request,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
token = None
@@ -181,7 +203,10 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
return user
else:
raise HTTPException(

View File

@@ -66,7 +66,7 @@ async def generate_direct_chat_completion(
user: Any,
models: dict,
):
print("generate_direct_chat_completion")
log.info("generate_direct_chat_completion")
metadata = form_data.pop("metadata", {})
@@ -103,7 +103,7 @@ async def generate_direct_chat_completion(
}
)
print("res", res)
log.info(f"res: {res}")
if res.get("status", False):
# Define a generator to stream responses
@@ -200,7 +200,7 @@ async def generate_chat_completion(
except Exception as e:
raise e
if model["owned_by"] == "arena":
if model.get("owned_by") == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
@@ -253,7 +253,7 @@ async def generate_chat_completion(
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
if model.get("owned_by") == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
@@ -285,7 +285,7 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
@@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
raise Exception(f"Action not found: {action_id}")
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
@@ -432,7 +432,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
)
)
except Exception as e:
print(e)
log.exception(f"Failed to get user values: {e}")
params = {**params, "__user__": __user__}

View File

@@ -1,6 +1,12 @@
import inspect
import logging
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.models.functions import Functions
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model):
@@ -61,7 +67,12 @@ async def process_filter_functions(
try:
# Prepare parameters
sig = inspect.signature(handler)
params = {"body": form_data} | {
params = {"body": form_data}
if filter_type == "stream":
params = {"event": form_data}
params = params | {
k: v
for k, v in {
**extra_params,
@@ -80,7 +91,7 @@ async def process_filter_functions(
)
)
except Exception as e:
print(e)
log.exception(f"Failed to get user values: {e}")
# Execute handler
if inspect.iscoroutinefunction(handler):
@@ -89,7 +100,7 @@ async def process_filter_functions(
form_data = handler(**params)
except Exception as e:
print(f"Error in {filter_type} handler {filter_id}: {e}")
log.exception(f"Error in {filter_type} handler {filter_id}: {e}")
raise e
# Handle file cleanup for inlet

View File

@@ -0,0 +1,140 @@
import json
import logging
import sys
from typing import TYPE_CHECKING
from loguru import logger
from open_webui.env import (
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
GLOBAL_LOG_LEVEL,
)
if TYPE_CHECKING:
from loguru import Record
def stdout_format(record: "Record") -> str:
    """
    Build the format template used for console log lines.

    As a side effect the record's extra data is serialized to JSON and stored
    under ``record["extra"]["extra_json"]`` so the returned template can
    reference it. Values that are not JSON-serializable are rendered with
    ``str()`` rather than raising, which matches how ``file_format``
    serializes its payload.

    Parameters:
    record (Record): A Loguru record; only its ``"extra"`` mapping is read here.

    Returns:
    str: The format template for stdout (timestamp, level, source location,
    message, serialized extra data, exception).
    """
    # default=str keeps a non-serializable bound value from raising inside
    # the formatter.
    record["extra"]["extra_json"] = json.dumps(record["extra"], default=str)
    return (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        "<level>{message}</level> - {extra[extra_json]}"
        "\n{exception}"
    )
class InterceptHandler(logging.Handler):
    """
    Standard-library logging handler that forwards every record it receives
    to Loguru's logger.
    """

    def emit(self, record):
        """
        Forward one standard ``LogRecord`` to Loguru.

        The Loguru level with the same name is used when one exists; otherwise
        the record's numeric level is passed through. The call depth is
        advanced past frames belonging to the ``logging`` module file so
        Loguru attributes the message to the code that issued it.
        """
        try:
            loguru_level = logger.level(record.levelname).name
        except ValueError:
            # No Loguru level with that name: fall back to the numeric level.
            loguru_level = record.levelno

        depth = 6
        caller = sys._getframe(depth)
        while caller and caller.f_code.co_filename == logging.__file__:
            caller = caller.f_back
            depth += 1

        forwarding_logger = logger.opt(depth=depth, exception=record.exc_info)
        forwarding_logger.log(loguru_level, record.getMessage())
def file_format(record: "Record"):
    """
    Build the format template used for audit log file lines.

    The audit fields are read from ``record["extra"]``, serialized to one JSON
    object and stored under ``record["extra"]["file_extra"]``, which the
    returned template references. Missing fields fall back to empty defaults;
    values that are not JSON-serializable are rendered with ``str()``.

    Parameters:
    record (Record): A Loguru record; ``"time"`` and ``"extra"`` are read.

    Returns:
    str: The format template that emits the JSON payload followed by a newline.
    """
    extra = record["extra"]
    audit_data = {
        "id": extra.get("id", ""),
        "timestamp": int(record["time"].timestamp()),
        "user": extra.get("user", dict()),
        "audit_level": extra.get("audit_level", ""),
        "verb": extra.get("verb", ""),
        "request_uri": extra.get("request_uri", ""),
        "response_status_code": extra.get("response_status_code", 0),
        "source_ip": extra.get("source_ip", ""),
        "user_agent": extra.get("user_agent", ""),
        # Empty *string* defaults: a bytes default would be stringified by
        # ``default=str`` and appear in the file as the literal text "b''".
        "request_object": extra.get("request_object", ""),
        "response_object": extra.get("response_object", ""),
        "extra": extra.get("extra", {}),
    }

    extra["file_extra"] = json.dumps(audit_data, default=str)
    return "{extra[file_extra]}\n"
def start_logger():
    """
    Configure Loguru and route standard-library logging through it.

    Steps, in order:
    1. Remove Loguru's existing handlers.
    2. Add a stdout handler at ``GLOBAL_LOG_LEVEL`` for every record that does
       not carry the ``auditable`` extra.
    3. Unless ``AUDIT_LOG_LEVEL`` is ``"NONE"``, add a rotating, zip-compressed
       file handler at ``AUDIT_LOGS_FILE_PATH`` for records whose
       ``auditable`` extra is ``True``; a failure to add it is logged and
       otherwise ignored.
    4. Replace the root logging configuration with an ``InterceptHandler``.
    5. Set the uvicorn loggers to ``GLOBAL_LOG_LEVEL``: ``uvicorn`` and
       ``uvicorn.error`` lose their own handlers, ``uvicorn.access`` gets its
       own ``InterceptHandler``.

    Takes no parameters; all settings come from ``open_webui.env``.
    """

    def is_regular_record(record):
        return "auditable" not in record["extra"]

    def is_audit_record(record):
        return record["extra"].get("auditable") is True

    logger.remove()

    logger.add(
        sys.stdout,
        level=GLOBAL_LOG_LEVEL,
        format=stdout_format,
        filter=is_regular_record,
    )

    audit_enabled = AUDIT_LOG_LEVEL != "NONE"
    if audit_enabled:
        try:
            logger.add(
                AUDIT_LOGS_FILE_PATH,
                level="INFO",
                rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
                compression="zip",
                format=file_format,
                filter=is_audit_record,
            )
        except Exception as e:
            logger.error(f"Failed to initialize audit log file handler: {str(e)}")

    logging.basicConfig(
        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
    )

    # Logger name -> whether it gets a dedicated InterceptHandler.
    uvicorn_loggers = (
        ("uvicorn", False),
        ("uvicorn.error", False),
        ("uvicorn.access", True),
    )
    for logger_name, needs_own_handler in uvicorn_loggers:
        std_logger = logging.getLogger(logger_name)
        std_logger.setLevel(GLOBAL_LOG_LEVEL)
        std_logger.handlers = [InterceptHandler()] if needs_own_handler else []

    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

View File

@@ -322,78 +322,95 @@ async def chat_web_search_handler(
)
return form_data
searchQuery = queries[0]
all_results = []
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Searching "{{searchQuery}}"',
"query": searchQuery,
"done": False,
},
}
)
try:
results = await process_web_search(
request,
SearchForm(
**{
for searchQuery in queries:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Searching "{{searchQuery}}"',
"query": searchQuery,
}
),
user,
"done": False,
},
}
)
if results:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "Searched {{count}} sites",
try:
results = await process_web_search(
request,
SearchForm(
**{
"query": searchQuery,
"urls": results["filenames"],
"done": True,
},
}
}
),
user=user,
)
files = form_data.get("files", [])
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search_results",
"urls": results["filenames"],
}
)
form_data["files"] = files
else:
if results:
all_results.append(results)
files = form_data.get("files", [])
if results.get("collection_name"):
files.append(
{
"collection_name": results["collection_name"],
"name": searchQuery,
"type": "web_search",
"urls": results["filenames"],
}
)
elif results.get("docs"):
files.append(
{
"docs": results.get("docs", []),
"name": searchQuery,
"type": "web_search",
"urls": results["filenames"],
}
)
form_data["files"] = files
except Exception as e:
log.exception(e)
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search results found",
"description": 'Error searching "{{searchQuery}}"',
"query": searchQuery,
"done": True,
"error": True,
},
}
)
except Exception as e:
log.exception(e)
if all_results:
urls = []
for results in all_results:
if "filenames" in results:
urls.extend(results["filenames"])
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": 'Error searching "{{searchQuery}}"',
"query": searchQuery,
"description": "Searched {{count}} sites",
"urls": urls,
"done": True,
},
}
)
else:
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search results found",
"done": True,
"error": True,
},
@@ -503,6 +520,7 @@ async def chat_completion_files_handler(
sources = []
if files := body.get("metadata", {}).get("files", None):
queries = []
try:
queries_response = await generate_queries(
request,
@@ -528,8 +546,8 @@ async def chat_completion_files_handler(
queries_response = {"queries": [queries_response]}
queries = queries_response.get("queries", [])
except Exception as e:
queries = []
except:
pass
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
@@ -541,6 +559,7 @@ async def chat_completion_files_handler(
sources = await loop.run_in_executor(
executor,
lambda: get_sources_from_files(
request=request,
files=files,
queries=queries,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
@@ -550,9 +569,9 @@ async def chat_completion_files_handler(
reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
),
)
except Exception as e:
log.exception(e)
@@ -728,6 +747,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Remove files duplicates
if files:
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
@@ -785,8 +805,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
if len(sources) > 0:
context_string = ""
for source_idx, source in enumerate(sources):
source_id = source.get("source", {}).get("name", "")
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -806,7 +824,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
if model.get("owned_by") == "ollama":
form_data["messages"] = prepend_to_first_user_message_content(
rag_template(
request.app.state.config.RAG_TEMPLATE, context_string, prompt
@@ -1038,6 +1056,21 @@ async def process_chat_response(
):
return response
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
"__model__": metadata.get("model"),
}
filter_ids = get_sorted_filter_ids(form_data.get("model"))
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1117,12 +1150,12 @@ async def process_chat_response(
if reasoning_duration is not None:
if raw:
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
else:
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
else:
if raw:
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
else:
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
@@ -1218,9 +1251,9 @@ async def process_chat_response(
return attributes
if content_blocks[-1]["type"] == "text":
for tag in tags:
for start_tag, end_tag in tags:
# Match start tag e.g., <tag> or <tag attr="value">
start_tag_pattern = rf"<{tag}(\s.*?)?>"
start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
match = re.search(start_tag_pattern, content)
if match:
attr_content = (
@@ -1253,7 +1286,8 @@ async def process_chat_response(
content_blocks.append(
{
"type": content_type,
"tag": tag,
"start_tag": start_tag,
"end_tag": end_tag,
"attributes": attributes,
"content": "",
"started_at": time.time(),
@@ -1265,9 +1299,10 @@ async def process_chat_response(
break
elif content_blocks[-1]["type"] == content_type:
tag = content_blocks[-1]["tag"]
start_tag = content_blocks[-1]["start_tag"]
end_tag = content_blocks[-1]["end_tag"]
# Match end tag e.g., </tag>
end_tag_pattern = rf"</{tag}>"
end_tag_pattern = rf"<{re.escape(end_tag)}>"
# Check if the content has the end tag
if re.search(end_tag_pattern, content):
@@ -1275,7 +1310,7 @@ async def process_chat_response(
block_content = content_blocks[-1]["content"]
# Strip start and end tags from the content
start_tag_pattern = rf"<{tag}(.*?)>"
start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
block_content = re.sub(
start_tag_pattern, "", block_content
).strip()
@@ -1340,7 +1375,7 @@ async def process_chat_response(
# Clean processed content
content = re.sub(
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
"",
content,
flags=re.DOTALL,
@@ -1353,7 +1388,22 @@ async def process_chat_response(
)
tool_calls = []
content = message.get("content", "") if message else ""
last_assistant_message = None
try:
if form_data["messages"][-1]["role"] == "assistant":
last_assistant_message = get_last_assistant_message(
form_data["messages"]
)
except Exception as e:
pass
content = (
message.get("content", "")
if message
else last_assistant_message if last_assistant_message else ""
)
content_blocks = [
{
"type": "text",
@@ -1363,19 +1413,24 @@ async def process_chat_response(
# We might want to disable this by default
DETECT_REASONING = True
DETECT_SOLUTION = True
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
"code_interpreter", False
)
reasoning_tags = [
"think",
"thinking",
"reason",
"reasoning",
"thought",
"Thought",
("think", "/think"),
("thinking", "/thinking"),
("reason", "/reason"),
("reasoning", "/reasoning"),
("thought", "/thought"),
("Thought", "/Thought"),
("|begin_of_thought|", "|end_of_thought|"),
]
code_interpreter_tags = ["code_interpreter"]
code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
try:
for event in events:
@@ -1419,119 +1474,154 @@ async def process_chat_response(
try:
data = json.loads(data)
if "selected_model_id" in data:
model_id = data["selected_model_id"]
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": model_id,
},
)
else:
choices = data.get("choices", [])
if not choices:
continue
data, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=data,
extra_params=extra_params,
)
delta = choices[0].get("delta", {})
delta_tool_calls = delta.get("tool_calls", None)
if delta_tool_calls:
for delta_tool_call in delta_tool_calls:
tool_call_index = delta_tool_call.get("index")
if tool_call_index is not None:
if (
len(response_tool_calls)
<= tool_call_index
):
response_tool_calls.append(
delta_tool_call
)
else:
delta_name = delta_tool_call.get(
"function", {}
).get("name")
delta_arguments = delta_tool_call.get(
"function", {}
).get("arguments")
if delta_name:
response_tool_calls[
tool_call_index
]["function"]["name"] += delta_name
if delta_arguments:
response_tool_calls[
tool_call_index
]["function"][
"arguments"
] += delta_arguments
value = delta.get("content")
if value:
content = f"{content}{value}"
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
content_blocks[-1]["content"] = (
content_blocks[-1]["content"] + value
if data:
if "selected_model_id" in data:
model_id = data["selected_model_id"]
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"selectedModelId": model_id,
},
)
if DETECT_REASONING:
content, content_blocks, _ = (
tag_content_handler(
"reasoning",
reasoning_tags,
content,
content_blocks,
else:
choices = data.get("choices", [])
if not choices:
usage = data.get("usage", {})
if usage:
await event_emitter(
{
"type": "chat:completion",
"data": {
"usage": usage,
},
}
)
continue
delta = choices[0].get("delta", {})
delta_tool_calls = delta.get("tool_calls", None)
if delta_tool_calls:
for delta_tool_call in delta_tool_calls:
tool_call_index = delta_tool_call.get(
"index"
)
if tool_call_index is not None:
if (
len(response_tool_calls)
<= tool_call_index
):
response_tool_calls.append(
delta_tool_call
)
else:
delta_name = delta_tool_call.get(
"function", {}
).get("name")
delta_arguments = (
delta_tool_call.get(
"function", {}
).get("arguments")
)
if delta_name:
response_tool_calls[
tool_call_index
]["function"][
"name"
] += delta_name
if delta_arguments:
response_tool_calls[
tool_call_index
]["function"][
"arguments"
] += delta_arguments
value = delta.get("content")
if value:
content = f"{content}{value}"
if not content_blocks:
content_blocks.append(
{
"type": "text",
"content": "",
}
)
content_blocks[-1]["content"] = (
content_blocks[-1]["content"] + value
)
if DETECT_CODE_INTERPRETER:
content, content_blocks, end = (
tag_content_handler(
"code_interpreter",
code_interpreter_tags,
content,
content_blocks,
if DETECT_REASONING:
content, content_blocks, _ = (
tag_content_handler(
"reasoning",
reasoning_tags,
content,
content_blocks,
)
)
)
if end:
break
if DETECT_CODE_INTERPRETER:
content, content_blocks, end = (
tag_content_handler(
"code_interpreter",
code_interpreter_tags,
content,
content_blocks,
)
)
if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
if end:
break
if DETECT_SOLUTION:
content, content_blocks, _ = (
tag_content_handler(
"solution",
solution_tags,
content,
content_blocks,
)
)
if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": serialize_content_blocks(
content_blocks
),
},
)
else:
data = {
"content": serialize_content_blocks(
content_blocks
),
},
)
else:
data = {
"content": serialize_content_blocks(
content_blocks
),
}
}
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
await event_emitter(
{
"type": "chat:completion",
"data": data,
}
)
except Exception as e:
done = "data: [DONE]" in line
if done:
@@ -1736,6 +1826,7 @@ async def process_chat_response(
== "password"
else None
),
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
)
else:
output = {
@@ -1829,7 +1920,10 @@ async def process_chat_response(
}
)
print(content_blocks, serialize_content_blocks(content_blocks))
log.info(f"content_blocks={content_blocks}")
log.info(
f"serialize_content_blocks={serialize_content_blocks(content_blocks)}"
)
try:
res = await generate_chat_completion(
@@ -1900,7 +1994,7 @@ async def process_chat_response(
await background_tasks_handler()
except asyncio.CancelledError:
print("Task was cancelled!")
log.warning("Task was cancelled!")
await event_emitter({"type": "task-cancelled"})
if not ENABLE_REALTIME_CHAT_SAVE:
@@ -1921,17 +2015,34 @@ async def process_chat_response(
return {"status": True, "task_id": task_id}
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n"
for event in events:
yield wrap_item(json.dumps(event))
event, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=event,
extra_params=extra_params,
)
if event:
yield wrap_item(json.dumps(event))
async for data in original_generator:
yield data
data, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=data,
extra_params=extra_params,
)
if data:
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, events),

View File

@@ -2,6 +2,7 @@ import hashlib
import re
import time
import uuid
import logging
from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
@@ -9,6 +10,10 @@ import json
import collections.abc
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def deep_update(d, u):
@@ -413,7 +418,7 @@ def parse_ollama_modelfile(model_text):
elif param_type is bool:
value = value.lower() == "true"
except Exception as e:
print(e)
log.exception(f"Failed to parse parameter {param}: {e}")
continue
data["params"][param] = value

View File

@@ -22,6 +22,7 @@ from open_webui.config import (
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.models.users import UserModel
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request):
async def get_all_base_models(request: Request, user: UserModel = None):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request)
openai_models = await openai.get_all_models(request, user=user)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request)
ollama_models = await ollama.get_all_models(request, user=user)
ollama_models = [
{
"id": model["model"],
@@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
return models
async def get_all_models(request):
models = await get_all_base_models(request)
async def get_all_models(request, user: UserModel = None):
models = await get_all_base_models(request, user=user)
# If there are no models, return an empty list
if len(models) == 0:
@@ -142,7 +143,7 @@ async def get_all_models(request):
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
owned_by = model.get("owned_by", "unknown owner")
if "pipe" in model:
pipe = model["pipe"]
break

View File

@@ -140,7 +140,14 @@ class OAuthManager:
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
# Nested claim search for groups claim
if oauth_claim:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
user_oauth_groups = claim_data if isinstance(claim_data, list) else []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
@@ -239,11 +246,46 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
email = user_data.get(email_claim, "")
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# If the provider is GitHub and no public email is provided, we can use the access token to fetch the user's email
if provider == "github":
try:
access_token = token.get("access_token")
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.github.com/user/emails", headers=headers
) as resp:
if resp.ok:
emails = await resp.json()
# use the primary email as the user's email
primary_email = next(
(e["email"] for e in emails if e.get("primary")),
None,
)
if primary_email:
email = primary_email
else:
log.warning(
"No primary email found in GitHub response"
)
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
else:
log.warning("Failed to fetch GitHub email")
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
except Exception as e:
log.warning(f"Error fetching GitHub email: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
email = email.lower()
if (
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
@@ -273,21 +315,10 @@ class OAuthManager:
if not user:
user_count = Users.get_num_users()
if (
request.app.state.USER_COUNT
and user_count >= request.app.state.USER_COUNT
):
raise HTTPException(
403,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
existing_user = Users.get_user_by_email(email)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

View File

@@ -4,6 +4,7 @@ from open_webui.utils.misc import (
)
from typing import Callable, Optional
import json
# inplace function: form_data is modified
@@ -67,38 +68,49 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
opts = [
"temperature",
"top_p",
"seed",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_batch",
"num_keep",
"repeat_last_n",
"tfs_z",
"top_k",
"min_p",
"use_mmap",
"use_mlock",
"num_thread",
"num_gpu",
]
mappings = {i: lambda x: x for i in opts}
form_data = apply_model_params_to_body(params, form_data, mappings)
# Convert OpenAI parameter names to Ollama parameter names if needed.
name_differences = {
"max_tokens": "num_predict",
"frequency_penalty": "repeat_penalty",
}
for key, value in name_differences.items():
if (param := params.get(key, None)) is not None:
form_data[value] = param
# Copy the parameter to its new name, then delete the old key, to prevent Ollama warning about an invalid option
params[value] = params[key]
del params[key]
return form_data
# See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8
mappings = {
"temperature": float,
"top_p": float,
"seed": lambda x: x,
"mirostat": int,
"mirostat_eta": float,
"mirostat_tau": float,
"num_ctx": int,
"num_batch": int,
"num_keep": int,
"num_predict": int,
"repeat_last_n": int,
"top_k": int,
"min_p": float,
"typical_p": float,
"repeat_penalty": float,
"presence_penalty": float,
"frequency_penalty": float,
"penalize_newline": bool,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
"numa": bool,
"num_gpu": int,
"main_gpu": int,
"low_vram": bool,
"vocab_only": bool,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
}
return apply_model_params_to_body(params, form_data, mappings)
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
@@ -109,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
new_message = {"role": message["role"]}
content = message.get("content", [])
tool_calls = message.get("tool_calls", None)
tool_call_id = message.get("tool_call_id", None)
# Check if the content is a string (just a simple message)
if isinstance(content, str):
if isinstance(content, str) and not tool_calls:
# If the content is a string, it's pure text
new_message["content"] = content
# If message is a tool call, add the tool call id to the message
if tool_call_id:
new_message["tool_call_id"] = tool_call_id
elif tool_calls:
# If tool calls are present, add them to the message
ollama_tool_calls = []
for tool_call in tool_calls:
ollama_tool_call = {
"index": tool_call.get("index", 0),
"id": tool_call.get("id", None),
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": json.loads(
tool_call.get("function", {}).get("arguments", {})
),
},
}
ollama_tool_calls.append(ollama_tool_call)
new_message["tool_calls"] = ollama_tool_calls
# Set the content to an empty string (Ollama requires an empty string for tool calls)
new_message["content"] = ""
else:
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
content_text = ""
@@ -174,33 +213,28 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
ollama_payload["format"] = openai_payload["format"]
# If there are advanced parameters in the payload, format them in Ollama's options field
ollama_options = {}
if openai_payload.get("options"):
ollama_payload["options"] = openai_payload["options"]
ollama_options = openai_payload["options"]
# Handle parameters which map directly
for param in ["temperature", "top_p", "seed"]:
if param in openai_payload:
ollama_options[param] = openai_payload[param]
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_tokens" in ollama_options:
ollama_options["num_predict"] = ollama_options["max_tokens"]
del ollama_options[
"max_tokens"
] # To prevent Ollama warning of invalid option provided
# Map OpenAI's `max_completion_tokens` (checked first) or `max_tokens` -> Ollama's `num_predict`
if "max_completion_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
elif "max_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_tokens"]
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
if "system" in ollama_options:
ollama_payload["system"] = ollama_options["system"]
del ollama_options[
"system"
] # To prevent Ollama warning of invalid option provided
# Handle frequency / presence_penalty, which needs renaming and checking
if "frequency_penalty" in openai_payload:
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
# We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
# Add options to payload if any have been set
if ollama_options:
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
if "stop" in openai_payload:
ollama_options = ollama_payload.get("options", {})
ollama_options["stop"] = openai_payload.get("stop")
ollama_payload["options"] = ollama_options
if "metadata" in openai_payload:

View File

@@ -45,7 +45,7 @@ def extract_frontmatter(content):
frontmatter[key.strip()] = value.strip()
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Failed to extract frontmatter: {e}")
return {}
return frontmatter

View File

@@ -24,17 +24,8 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
return openai_tool_calls
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
if tool_calls:
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
data = ollama_response
usage = {
def convert_ollama_usage_to_openai(data: dict) -> dict:
return {
"response_token/s": (
round(
(
@@ -66,14 +57,42 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_tokens": int(
data.get("prompt_eval_count", 0)
), # This is the OpenAI compatible key
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"completion_tokens": int(
data.get("eval_count", 0)
), # This is the OpenAI compatible key
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
(data.get("total_duration", 0) or 0) // 1_000_000_000
),
"total_tokens": int( # This is the OpenAI compatible key
data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
),
"completion_tokens_details": { # This is the OpenAI compatible key
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
}
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
if tool_calls:
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
data = ollama_response
usage = convert_ollama_usage_to_openai(data)
response = openai_chat_completion_message_template(
model, message_content, openai_tool_calls, usage
)
@@ -85,7 +104,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
data = json.loads(data)
model = data.get("model", "ollama")
message_content = data.get("message", {}).get("content", "")
message_content = data.get("message", {}).get("content", None)
tool_calls = data.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
@@ -96,48 +115,10 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
usage = None
if done:
usage = {
"response_token/s": (
round(
(
(
data.get("eval_count", 0)
/ ((data.get("eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("eval_duration", 0) > 0
else "N/A"
),
"prompt_token/s": (
round(
(
(
data.get("prompt_eval_count", 0)
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("prompt_eval_duration", 0) > 0
else "N/A"
),
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
}
usage = convert_ollama_usage_to_openai(data)
data = openai_chat_chunk_message_template(
model, message_content if not done else None, openai_tool_calls, usage
model, message_content, openai_tool_calls, usage
)
line = f"data: {json.dumps(data)}\n\n"

View File

@@ -22,7 +22,7 @@ def get_task_model_id(
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id]["owned_by"] == "ollama":
if models[task_model_id].get("owned_by") == "ollama":
if task_model and task_model in models:
task_model_id = task_model
else: