mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'origin' into logit_bias
This commit is contained in:
@@ -9,7 +9,6 @@ from pathlib import Path
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import chromadb
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
||||
@@ -44,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
|
||||
|
||||
# Function to run the alembic migrations
|
||||
def run_migrations():
|
||||
print("Running migrations")
|
||||
log.info("Running migrations")
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
@@ -57,7 +56,7 @@ def run_migrations():
|
||||
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
log.exception(f"Error running migrations: {e}")
|
||||
|
||||
|
||||
run_migrations()
|
||||
@@ -588,20 +587,6 @@ load_oauth_providers()
|
||||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
|
||||
|
||||
|
||||
def override_static(path: str, content: str):
    """Write Base64-encoded ``content`` to the file ``path`` inside STATIC_DIR.

    ``path`` must be a bare file name. A name that is empty, or that contains
    ``/``, a backslash or ``..``, is rejected: the problem is logged and
    nothing is written.

    Args:
        path: File name, relative to STATIC_DIR.
        content: Base64-encoded file contents.

    Raises:
        binascii.Error: Propagated from ``base64.b64decode`` when ``content``
            is not valid Base64.
    """
    # Ensure path is safe. The backslash is rejected as well because on
    # Windows os.path.join() drops STATIC_DIR when the joined component starts
    # with a separator, which would let the write land outside the directory.
    # An empty name would resolve to STATIC_DIR itself, which cannot be opened
    # as a file.
    if not path or "/" in path or "\\" in path or ".." in path:
        log.error(f"Invalid path: {path}")
        return

    file_path = os.path.join(STATIC_DIR, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(base64.b64decode(content))  # Convert Base64 back to raw binary
|
||||
|
||||
|
||||
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
|
||||
|
||||
if frontend_favicon.exists():
|
||||
@@ -692,12 +677,20 @@ S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
|
||||
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
|
||||
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
|
||||
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
|
||||
S3_USE_ACCELERATE_ENDPOINT = (
|
||||
os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true"
|
||||
)
|
||||
S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None)
|
||||
|
||||
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
|
||||
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
|
||||
"GOOGLE_APPLICATION_CREDENTIALS_JSON", None
|
||||
)
|
||||
|
||||
AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None)
|
||||
AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None)
|
||||
AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None)
|
||||
|
||||
####################################
|
||||
# File Upload DIR
|
||||
####################################
|
||||
@@ -797,6 +790,9 @@ ENABLE_OPENAI_API = PersistentConfig(
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
||||
|
||||
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
||||
GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
|
||||
|
||||
|
||||
if OPENAI_API_BASE_URL == "":
|
||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||
@@ -1101,7 +1097,7 @@ try:
|
||||
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
|
||||
banners = [BannerModel(**banner) for banner in banners]
|
||||
except Exception as e:
|
||||
print(f"Error loading WEBUI_BANNERS: {e}")
|
||||
log.exception(f"Error loading WEBUI_BANNERS: {e}")
|
||||
banners = []
|
||||
|
||||
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
|
||||
@@ -1377,6 +1373,44 @@ Responses from models: {{responses}}"""
|
||||
# Code Interpreter
|
||||
####################################
|
||||
|
||||
|
||||
CODE_EXECUTION_ENGINE = PersistentConfig(
|
||||
"CODE_EXECUTION_ENGINE",
|
||||
"code_execution.engine",
|
||||
os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_URL = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_URL",
|
||||
"code_execution.jupyter.url",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH",
|
||||
"code_execution.jupyter.auth",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN",
|
||||
"code_execution.jupyter.auth_token",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD",
|
||||
"code_execution.jupyter.auth_password",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT",
|
||||
"code_execution.jupyter.timeout",
|
||||
int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")),
|
||||
)
|
||||
|
||||
ENABLE_CODE_INTERPRETER = PersistentConfig(
|
||||
"ENABLE_CODE_INTERPRETER",
|
||||
"code_interpreter.enable",
|
||||
@@ -1398,26 +1432,48 @@ CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
|
||||
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_URL",
|
||||
"code_interpreter.jupyter.url",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "")
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH",
|
||||
"code_interpreter.jupyter.auth",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
|
||||
"code_interpreter.jupyter.auth_token",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
|
||||
"code_interpreter.jupyter.auth_password",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
|
||||
"code_interpreter.jupyter.timeout",
|
||||
int(
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1445,21 +1501,27 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
|
||||
|
||||
# Chroma
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
# The remaining Chroma settings are only defined when Chroma is the selected
# vector DB, so chromadb is imported only in that case.
if VECTOR_DB == "chroma":
    import chromadb

    CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
    CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
    CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
    CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
    CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
    CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get(
        "CHROMA_CLIENT_AUTH_CREDENTIALS", ""
    )
    # Comma-separated list of header=value pairs
    CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
    if CHROMA_HTTP_HEADERS:
        # Split each pair on the first "=" only, so a value that itself
        # contains "=" (e.g. Base64 padding in a token) stays intact instead
        # of making dict() raise ValueError.
        CHROMA_HTTP_HEADERS = dict(
            [pair.split("=", 1) for pair in CHROMA_HTTP_HEADERS.split(",")]
        )
    else:
        CHROMA_HTTP_HEADERS = None
    CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
# Milvus
|
||||
@@ -1513,6 +1575,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig(
|
||||
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
|
||||
)
|
||||
|
||||
ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
|
||||
"ENABLE_ONEDRIVE_INTEGRATION",
|
||||
"onedrive.enable",
|
||||
os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
ONEDRIVE_CLIENT_ID = PersistentConfig(
|
||||
"ONEDRIVE_CLIENT_ID",
|
||||
"onedrive.client_id",
|
||||
os.environ.get("ONEDRIVE_CLIENT_ID", ""),
|
||||
)
|
||||
|
||||
# RAG Content Extraction
|
||||
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
"CONTENT_EXTRACTION_ENGINE",
|
||||
@@ -1526,6 +1600,26 @@ TIKA_SERVER_URL = PersistentConfig(
|
||||
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||
"rag.document_intelligence_endpoint",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_KEY",
|
||||
"rag.document_intelligence_key",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
|
||||
)
|
||||
|
||||
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL",
|
||||
"rag.bypass_embedding_and_retrieval",
|
||||
os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
RAG_TOP_K = PersistentConfig(
|
||||
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
|
||||
)
|
||||
@@ -1541,6 +1635,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
|
||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_FULL_CONTEXT = PersistentConfig(
|
||||
"RAG_FULL_CONTEXT",
|
||||
"rag.full_context",
|
||||
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_FILE_MAX_COUNT = PersistentConfig(
|
||||
"RAG_FILE_MAX_COUNT",
|
||||
"rag.file.max_count",
|
||||
@@ -1655,7 +1755,7 @@ Respond to the user query using the provided context, incorporating inline citat
|
||||
- Respond in the same language as the user's query.
|
||||
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
|
||||
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
|
||||
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
|
||||
- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `<source_id>` tag is explicitly provided in the context.**
|
||||
- Do not cite if the <source_id> tag is not provided in the context.
|
||||
- Do not use XML tags in your response.
|
||||
- Ensure citations are concise and directly related to the information provided.
|
||||
@@ -1736,6 +1836,12 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
|
||||
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
|
||||
)
|
||||
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
|
||||
"rag.web.search.bypass_embedding_and_retrieval",
|
||||
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
|
||||
)
|
||||
|
||||
# You can provide a list of your own websites to filter after performing a web search.
|
||||
# This ensures the highest level of safety and reliability of the information sources.
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
@@ -1883,10 +1989,34 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
RAG_WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"RAG_WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_TRUST_ENV",
|
||||
"rag.web.search.trust_env",
|
||||
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False),
|
||||
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
|
||||
)
|
||||
|
||||
PLAYWRIGHT_WS_URI = PersistentConfig(
|
||||
"PLAYWRIGHT_WS_URI",
|
||||
"rag.web.loader.engine.playwright.ws.uri",
|
||||
os.environ.get("PLAYWRIGHT_WS_URI", None),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_KEY = PersistentConfig(
|
||||
"FIRECRAWL_API_KEY",
|
||||
"firecrawl.api_key",
|
||||
os.environ.get("FIRECRAWL_API_KEY", ""),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_BASE_URL = PersistentConfig(
|
||||
"FIRECRAWL_API_BASE_URL",
|
||||
"firecrawl.api_url",
|
||||
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
|
||||
)
|
||||
|
||||
####################################
|
||||
@@ -2099,6 +2229,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
|
||||
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||
)
|
||||
|
||||
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
|
||||
"IMAGES_GEMINI_API_BASE_URL",
|
||||
"image_generation.gemini.api_base_url",
|
||||
os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
|
||||
)
|
||||
IMAGES_GEMINI_API_KEY = PersistentConfig(
|
||||
"IMAGES_GEMINI_API_KEY",
|
||||
"image_generation.gemini.api_key",
|
||||
os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
|
||||
)
|
||||
|
||||
IMAGE_SIZE = PersistentConfig(
|
||||
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
|
||||
)
|
||||
@@ -2275,7 +2416,7 @@ LDAP_SEARCH_BASE = PersistentConfig(
|
||||
LDAP_SEARCH_FILTERS = PersistentConfig(
|
||||
"LDAP_SEARCH_FILTER",
|
||||
"ldap.server.search_filter",
|
||||
os.environ.get("LDAP_SEARCH_FILTER", ""),
|
||||
os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")),
|
||||
)
|
||||
|
||||
LDAP_USE_TLS = PersistentConfig(
|
||||
|
||||
@@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
||||
|
||||
if OFFLINE_MODE:
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
####################################
|
||||
# AUDIT LOGGING
|
||||
####################################
|
||||
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
|
||||
# Where to store the audit log file
|
||||
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||
# Maximum size of a file before rotating into a new log file
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||
# METADATA | REQUEST | REQUEST_RESPONSE
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
|
||||
# Falls back to 2048 when MAX_BODY_LOG_SIZE is unset, empty, or not an integer.
try:
    MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
    MAX_BODY_LOG_SIZE = 2048

# Comma-separated list of URL paths to exclude from audit. Each entry is
# stripped of surrounding whitespace and then of leading slashes; entries that
# end up empty (e.g. from a trailing comma) are dropped so the list never
# contains "".
AUDIT_EXCLUDED_PATHS = [
    path
    for path in (
        raw.strip().lstrip("/")
        for raw in os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
            ","
        )
    )
    if path
]
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import sys
|
||||
import inspect
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import AsyncGenerator, Generator, Iterator
|
||||
@@ -76,11 +77,13 @@ async def get_function_models(request):
|
||||
if hasattr(function_module, "pipes"):
|
||||
sub_pipes = []
|
||||
|
||||
# Check if pipes is a function or a list
|
||||
|
||||
# Handle pipes being a list, sync function, or async function
|
||||
try:
|
||||
if callable(function_module.pipes):
|
||||
sub_pipes = function_module.pipes()
|
||||
if asyncio.iscoroutinefunction(function_module.pipes):
|
||||
sub_pipes = await function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes
|
||||
except Exception as e:
|
||||
|
||||
@@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils import logger
|
||||
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
|
||||
from open_webui.utils.logger import start_logger
|
||||
from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
@@ -95,12 +98,19 @@ from open_webui.config import (
|
||||
OLLAMA_API_CONFIGS,
|
||||
# OpenAI
|
||||
ENABLE_OPENAI_API,
|
||||
ONEDRIVE_CLIENT_ID,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
OPENAI_API_CONFIGS,
|
||||
# Direct Connections
|
||||
ENABLE_DIRECT_CONNECTIONS,
|
||||
# Code Interpreter
|
||||
# Code Execution
|
||||
CODE_EXECUTION_ENGINE,
|
||||
CODE_EXECUTION_JUPYTER_URL,
|
||||
CODE_EXECUTION_JUPYTER_AUTH,
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
ENABLE_CODE_INTERPRETER,
|
||||
CODE_INTERPRETER_ENGINE,
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
@@ -108,6 +118,7 @@ from open_webui.config import (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
# Image
|
||||
AUTOMATIC1111_API_AUTH,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
@@ -126,6 +137,8 @@ from open_webui.config import (
|
||||
IMAGE_STEPS,
|
||||
IMAGES_OPENAI_API_BASE_URL,
|
||||
IMAGES_OPENAI_API_KEY,
|
||||
IMAGES_GEMINI_API_BASE_URL,
|
||||
IMAGES_GEMINI_API_KEY,
|
||||
# Audio
|
||||
AUDIO_STT_ENGINE,
|
||||
AUDIO_STT_MODEL,
|
||||
@@ -140,6 +153,10 @@ from open_webui.config import (
|
||||
AUDIO_TTS_VOICE,
|
||||
AUDIO_TTS_AZURE_SPEECH_REGION,
|
||||
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
WHISPER_MODEL,
|
||||
DEEPGRAM_API_KEY,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
@@ -147,6 +164,8 @@ from open_webui.config import (
|
||||
# Retrieval
|
||||
RAG_TEMPLATE,
|
||||
DEFAULT_RAG_TEMPLATE,
|
||||
RAG_FULL_CONTEXT,
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
@@ -166,6 +185,8 @@ from open_webui.config import (
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
RAG_TOP_K,
|
||||
RAG_TEXT_SPLITTER,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
@@ -174,6 +195,7 @@ from open_webui.config import (
|
||||
YOUTUBE_LOADER_PROXY_URL,
|
||||
# Retrieval (Web Search)
|
||||
RAG_WEB_SEARCH_ENGINE,
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
RAG_WEB_SEARCH_TRUST_ENV,
|
||||
@@ -200,11 +222,13 @@ from open_webui.config import (
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
GOOGLE_DRIVE_CLIENT_ID,
|
||||
GOOGLE_DRIVE_API_KEY,
|
||||
ONEDRIVE_CLIENT_ID,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
ENABLE_ONEDRIVE_INTEGRATION,
|
||||
UPLOAD_DIR,
|
||||
# WebUI
|
||||
WEBUI_AUTH,
|
||||
@@ -283,8 +307,11 @@ from open_webui.config import (
|
||||
reset_config,
|
||||
)
|
||||
from open_webui.env import (
|
||||
AUDIT_EXCLUDED_PATHS,
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
MAX_BODY_LOG_SIZE,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
@@ -369,6 +396,7 @@ https://github.com/open-webui/open-webui
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
start_logger()
|
||||
if RESET_CONFIG_ON_START:
|
||||
reset_config()
|
||||
|
||||
@@ -509,6 +537,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
||||
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
|
||||
|
||||
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
||||
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
@@ -516,6 +547,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||
|
||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||
@@ -543,9 +576,13 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
|
||||
|
||||
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
||||
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
|
||||
@@ -569,7 +606,11 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
|
||||
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
|
||||
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
|
||||
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = None
|
||||
app.state.ef = None
|
||||
@@ -613,10 +654,19 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
|
||||
########################################
|
||||
#
|
||||
# CODE INTERPRETER
|
||||
# CODE EXECUTION
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
|
||||
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
|
||||
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
|
||||
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
|
||||
@@ -629,6 +679,7 @@ app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
|
||||
########################################
|
||||
#
|
||||
@@ -643,6 +694,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
|
||||
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||
|
||||
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
||||
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
||||
|
||||
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
|
||||
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
@@ -844,6 +898,19 @@ app.include_router(
|
||||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||
|
||||
|
||||
# Resolve the configured audit level; an unrecognized value disables auditing.
try:
    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
except ValueError as e:
    # NOTE(review): `logger` in this module is bound by
    # `from open_webui.utils import logger`, i.e. it appears to be the module
    # rather than a logger object, so `logger.error(...)` would presumably
    # raise AttributeError here — confirm. A stdlib logger is used instead so
    # the fallback below is always reached.
    import logging

    logging.getLogger(__name__).error(
        f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}"
    )
    audit_level = AuditLevel.NONE

# Register the audit middleware only when auditing is enabled.
if audit_level != AuditLevel.NONE:
    app.add_middleware(
        AuditLoggingMiddleware,
        audit_level=audit_level,
        excluded_paths=AUDIT_EXCLUDED_PATHS,
        max_body_size=MAX_BODY_LOG_SIZE,
    )
|
||||
##################################
|
||||
#
|
||||
# Chat Endpoints
|
||||
@@ -876,7 +943,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
return filtered_models
|
||||
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
@@ -905,7 +972,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
@app.get("/api/models/base")
|
||||
async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
||||
models = await get_all_base_models(request)
|
||||
models = await get_all_base_models(request, user=user)
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@@ -916,7 +983,7 @@ async def chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
model_item = form_data.pop("model_item", {})
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
@@ -952,7 +1019,7 @@ async def chat_completion(
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
"variables": form_data.get("variables", None),
|
||||
"model": model_info,
|
||||
"model": model_info.model_dump() if model_info else model,
|
||||
"direct": model_item.get("direct", False),
|
||||
**(
|
||||
{"function_calling": "native"}
|
||||
@@ -1111,6 +1178,7 @@ async def get_app_config(request: Request):
|
||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
}
|
||||
if user is not None
|
||||
else {}
|
||||
@@ -1120,6 +1188,9 @@ async def get_app_config(request: Request):
|
||||
{
|
||||
"default_models": app.state.config.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
"code": {
|
||||
"engine": app.state.config.CODE_EXECUTION_ENGINE,
|
||||
},
|
||||
"audio": {
|
||||
"tts": {
|
||||
"engine": app.state.config.TTS_ENGINE,
|
||||
@@ -1139,6 +1210,7 @@ async def get_app_config(request: Request):
|
||||
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
||||
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
||||
},
|
||||
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
|
||||
}
|
||||
if user is not None
|
||||
else {}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
@@ -5,7 +6,7 @@ from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
@@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
|
||||
# Chat DB Schema
|
||||
####################
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chat"
|
||||
@@ -670,7 +674,7 @@ class ChatTable:
|
||||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
print(len(all_chats))
|
||||
log.info(f"The number of chats: {len(all_chats)}")
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
@@ -731,7 +735,7 @@ class ChatTable:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
print(db.bind.dialect.name)
|
||||
log.info(f"DB dialect name: {db.bind.dialect.name}")
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 querying for tags within the meta JSON field
|
||||
query = query.filter(
|
||||
@@ -752,7 +756,7 @@ class ChatTable:
|
||||
)
|
||||
|
||||
all_chats = query.all()
|
||||
print("all_chats", all_chats)
|
||||
log.debug(f"all_chats: {all_chats}")
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
@@ -810,7 +814,7 @@ class ChatTable:
|
||||
count = query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
print(f"Count of chats for tag '{tag_name}':", count)
|
||||
log.info(f"Count of chats for tag '{tag_name}': {count}")
|
||||
|
||||
return count
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class FeedbackTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new feedback: {e}")
|
||||
return None
|
||||
|
||||
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
|
||||
|
||||
@@ -119,7 +119,7 @@ class FilesTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error inserting a new file: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||
|
||||
@@ -82,7 +82,7 @@ class FolderTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error inserting a new folder: {e}")
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
|
||||
@@ -105,7 +105,7 @@ class FunctionsTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error creating a new function: {e}")
|
||||
return None
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
@@ -170,7 +170,7 @@ class FunctionsTable:
|
||||
function = db.get(Function, id)
|
||||
return function.valves if function.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Error getting function valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_function_valves_by_id(
|
||||
@@ -202,7 +202,9 @@ class FunctionsTable:
|
||||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
@@ -225,7 +227,9 @@ class FunctionsTable:
|
||||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
|
||||
5
backend/open_webui/models/models.py
Normal file → Executable file
5
backend/open_webui/models/models.py
Normal file → Executable file
@@ -166,7 +166,7 @@ class ModelsTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to insert a new model: {e}")
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> list[ModelModel]:
|
||||
@@ -246,8 +246,7 @@ class ModelsTable:
|
||||
db.refresh(model)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
|
||||
@@ -61,7 +61,7 @@ class TagTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error inserting a new tag: {e}")
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
|
||||
@@ -131,7 +131,7 @@ class ToolsTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error creating a new tool: {e}")
|
||||
return None
|
||||
|
||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
@@ -175,7 +175,7 @@ class ToolsTable:
|
||||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Error getting tool valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
@@ -204,7 +204,9 @@ class ToolsTable:
|
||||
|
||||
return user_settings["tools"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
@@ -227,7 +229,9 @@ class ToolsTable:
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
|
||||
@@ -4,6 +4,7 @@ import ftfy
|
||||
import sys
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
AzureAIDocumentIntelligenceLoader,
|
||||
BSHTMLLoader,
|
||||
CSVLoader,
|
||||
Docx2txtLoader,
|
||||
@@ -76,6 +77,7 @@ known_source_ext = [
|
||||
"jsx",
|
||||
"hs",
|
||||
"lhs",
|
||||
"json",
|
||||
]
|
||||
|
||||
|
||||
@@ -147,6 +149,27 @@ class Loader:
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
|
||||
and (
|
||||
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
|
||||
or file_content_type
|
||||
in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
]
|
||||
)
|
||||
):
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
loader = PyPDFLoader(
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from colbert.infra import ColBERTConfig
|
||||
from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ColBERT:
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
print("ColBERT: Loading model", name)
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
DOCKER = kwargs.get("env") == "docker"
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
|
||||
import asyncio
|
||||
import requests
|
||||
import hashlib
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
@@ -14,8 +15,10 @@ from langchain_core.documents import Document
|
||||
|
||||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.files import Files
|
||||
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
@@ -80,7 +83,20 @@ def query_doc(
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_doc(collection_name: str, user: UserModel = None):
    """Fetch the stored items of a vector collection.

    Args:
        collection_name: Name of the collection handed to
            ``VECTOR_DB_CLIENT.get``.
        user: Not used by this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.get`` returns for the collection; a falsy
        result is returned unchanged and simply not logged.

    Raises:
        Exception: Any error raised while fetching or logging the result is
            logged with its traceback and then re-raised.
    """
    try:
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            # Tag the entry with this function's own name so it is not
            # mistaken for output of query_doc when reading the logs.
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
|
||||
|
||||
|
||||
@@ -137,47 +153,80 @@ def query_doc_with_hybrid_search(
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> list[dict]:
|
||||
def merge_get_results(get_results: list[dict]) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined_distances = []
|
||||
combined_documents = []
|
||||
combined_metadatas = []
|
||||
combined_ids = []
|
||||
|
||||
for data in query_results:
|
||||
combined_distances.extend(data["distances"][0])
|
||||
for data in get_results:
|
||||
combined_documents.extend(data["documents"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
combined_ids.extend(data["ids"][0])
|
||||
|
||||
# Create a list of tuples (distance, document, metadata)
|
||||
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"documents": [combined_documents],
|
||||
"metadatas": [combined_metadatas],
|
||||
"ids": [combined_ids],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined = []
|
||||
seen_hashes = set() # To store unique document hashes
|
||||
|
||||
for data in query_results:
|
||||
distances = data["distances"][0]
|
||||
documents = data["documents"][0]
|
||||
metadatas = data["metadatas"][0]
|
||||
|
||||
for distance, document, metadata in zip(distances, documents, metadatas):
|
||||
if isinstance(document, str):
|
||||
doc_hash = hashlib.md5(
|
||||
document.encode()
|
||||
).hexdigest() # Compute a hash for uniqueness
|
||||
|
||||
if doc_hash not in seen_hashes:
|
||||
seen_hashes.add(doc_hash)
|
||||
combined.append((distance, document, metadata))
|
||||
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
|
||||
# We don't have anything :-(
|
||||
if not combined:
|
||||
sorted_distances = []
|
||||
sorted_documents = []
|
||||
sorted_metadatas = []
|
||||
else:
|
||||
# Unzip the sorted list
|
||||
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
||||
# Slice to keep only the top k elements
|
||||
sorted_distances, sorted_documents, sorted_metadatas = (
|
||||
zip(*combined[:k]) if combined else ([], [], [])
|
||||
)
|
||||
|
||||
# Slicing the lists to include only k elements
|
||||
sorted_distances = list(sorted_distances)[:k]
|
||||
sorted_documents = list(sorted_documents)[:k]
|
||||
sorted_metadatas = list(sorted_metadatas)[:k]
|
||||
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"distances": [sorted_distances],
|
||||
"documents": [sorted_documents],
|
||||
"metadatas": [sorted_metadatas],
|
||||
# Create and return the output dictionary
|
||||
return {
|
||||
"distances": [list(sorted_distances)],
|
||||
"documents": [list(sorted_documents)],
|
||||
"metadatas": [list(sorted_metadatas)],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Collect the full contents of several collections into one result.

    Each non-empty name is looked up with ``get_doc``; every result that is
    not ``None`` is converted with ``model_dump()`` and the dumps are combined
    by ``merge_get_results``.

    Args:
        collection_names: Collection names to read; falsy entries are ignored.

    Returns:
        The value returned by ``merge_get_results`` for the collected dumps.
    """
    dumped = []

    for name in collection_names:
        # Skip blank / None names instead of querying the vector store.
        if not name:
            continue

        try:
            fetched = get_doc(collection_name=name)
            if fetched is not None:
                dumped.append(fetched.model_dump())
        except Exception as e:
            # Best effort: a failing collection is logged and the others are
            # still processed.
            log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(dumped)
|
||||
|
||||
|
||||
def query_collection(
|
||||
@@ -290,6 +339,7 @@ def get_embedding_function(
|
||||
|
||||
|
||||
def get_sources_from_files(
|
||||
request,
|
||||
files,
|
||||
queries,
|
||||
embedding_function,
|
||||
@@ -297,21 +347,74 @@ def get_sources_from_files(
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
):
|
||||
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
||||
log.debug(
|
||||
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
|
||||
)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for file in files:
|
||||
if file.get("context") == "full":
|
||||
|
||||
context = None
|
||||
if file.get("docs"):
|
||||
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
context = {
|
||||
"documents": [[doc.get("content") for doc in file.get("docs")]],
|
||||
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
|
||||
}
|
||||
elif file.get("context") == "full":
|
||||
# Manual Full Mode Toggle
|
||||
context = {
|
||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
||||
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
||||
}
|
||||
else:
|
||||
context = None
|
||||
elif (
|
||||
file.get("type") != "web_search"
|
||||
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
# BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if file.get("type") == "collection":
|
||||
file_ids = file.get("data", {}).get("file_ids", [])
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
for file_id in file_ids:
|
||||
file_object = Files.get_file_by_id(file_id)
|
||||
|
||||
if file_object:
|
||||
documents.append(file_object.data.get("content", ""))
|
||||
metadatas.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
)
|
||||
|
||||
context = {
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
|
||||
elif file.get("id"):
|
||||
file_object = Files.get_file_by_id(file.get("id"))
|
||||
if file_object:
|
||||
context = {
|
||||
"documents": [[file_object.data.get("content", "")]],
|
||||
"metadatas": [
|
||||
[
|
||||
{
|
||||
"file_id": file.get("id"),
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
else:
|
||||
collection_names = []
|
||||
if file.get("type") == "collection":
|
||||
if file.get("legacy"):
|
||||
@@ -331,42 +434,50 @@ def get_sources_from_files(
|
||||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
if full_context:
|
||||
try:
|
||||
context = get_all_items_from_collections(collection_names)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
else:
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
if context:
|
||||
if "data" in file:
|
||||
del file["data"]
|
||||
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
sources = []
|
||||
@@ -463,7 +574,7 @@ def generate_openai_batch_embeddings(
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating openai batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -497,7 +608,7 @@ def generate_ollama_batch_embeddings(
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating ollama batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
8
backend/open_webui/retrieval/vector/dbs/chroma.py
Normal file → Executable file
8
backend/open_webui/retrieval/vector/dbs/chroma.py
Normal file → Executable file
@@ -1,4 +1,5 @@
|
||||
import chromadb
|
||||
import logging
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
@@ -16,6 +17,10 @@ from open_webui.config import (
|
||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
@@ -102,8 +107,7 @@ class ChromaClient:
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
except:
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
@@ -10,6 +10,10 @@ from open_webui.config import (
|
||||
MILVUS_DB,
|
||||
MILVUS_TOKEN,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
@@ -168,7 +172,7 @@ class MilvusClient:
|
||||
try:
|
||||
# Loop until there are no more items to fetch or the desired limit is reached
|
||||
while remaining > 0:
|
||||
print("remaining", remaining)
|
||||
log.info(f"remaining: {remaining}")
|
||||
current_fetch = min(
|
||||
max_limit, remaining
|
||||
) # Determine how many items to fetch in this iteration
|
||||
@@ -195,10 +199,12 @@ class MilvusClient:
|
||||
if results_count < current_fetch:
|
||||
break
|
||||
|
||||
print(all_results)
|
||||
log.debug(all_results)
|
||||
return self._result_to_get_result([all_results])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(
|
||||
f"Error querying collection {collection_name} with limit {limit}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
from sqlalchemy import (
|
||||
cast,
|
||||
column,
|
||||
@@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
@@ -82,10 +88,10 @@ class PgvectorClient:
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
print("Initialization complete.")
|
||||
log.info("Initialization complete.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during initialization: {e}")
|
||||
log.exception(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
@@ -150,12 +156,12 @@ class PgvectorClient:
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
print(
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during insert: {e}")
|
||||
log.exception(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -184,10 +190,12 @@ class PgvectorClient:
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during upsert: {e}")
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -278,7 +286,7 @@ class PgvectorClient:
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during search: {e}")
|
||||
log.exception(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
@@ -310,7 +318,7 @@ class PgvectorClient:
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during query: {e}")
|
||||
log.exception(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
def get(
|
||||
@@ -334,7 +342,7 @@ class PgvectorClient:
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
print(f"Error during get: {e}")
|
||||
log.exception(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
def delete(
|
||||
@@ -356,22 +364,22 @@ class PgvectorClient:
|
||||
)
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
print(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during delete: {e}")
|
||||
log.exception(f"Error during delete: {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
print(
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during reset: {e}")
|
||||
log.exception(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -387,9 +395,9 @@ class PgvectorClient:
|
||||
)
|
||||
return exists
|
||||
except Exception as e:
|
||||
print(f"Error checking collection existence: {e}")
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
self.delete(collection_name)
|
||||
print(f"Collection '{collection_name}' deleted.")
|
||||
log.info(f"Collection '{collection_name}' deleted.")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
@@ -6,9 +7,13 @@ from qdrant_client.models import models
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
@@ -49,7 +54,7 @@ class QdrantClient:
|
||||
),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
log.info(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(collection_name=collection_name):
|
||||
@@ -120,7 +125,7 @@ class QdrantClient:
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error querying a collection '{collection_name}': {e}")
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
||||
@@ -27,8 +27,7 @@ def search_tavily(
|
||||
"""
|
||||
url = "https://api.tavily.com/search"
|
||||
data = {"query": query, "api_key": api_key}
|
||||
include_domain = filter_list
|
||||
response = requests.post(url, include_domain, json=data)
|
||||
response = requests.post(url, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
|
||||
@@ -1,22 +1,38 @@
|
||||
import socket
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
import validators
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
|
||||
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
Literal,
|
||||
)
|
||||
import aiohttp
|
||||
import certifi
|
||||
import validators
|
||||
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
||||
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@@ -68,6 +84,314 @@ def resolve_hostname(hostname):
|
||||
return ipv4_addresses, ipv6_addresses
|
||||
|
||||
|
||||
def extract_metadata(soup, url):
    """Build a metadata dict for a parsed page.

    Args:
        soup: Parsed document exposing ``find(name, attrs=...)``; presumably a
            BeautifulSoup object -- confirm against the caller.
        url: Value stored under the ``"source"`` key.

    Returns:
        A dict that always contains ``"source"`` and, when the matching
        element is found, ``"title"``, ``"description"`` and ``"language"``.
    """
    info = {"source": url}

    title_tag = soup.find("title")
    if title_tag:
        info["title"] = title_tag.get_text()

    description_tag = soup.find("meta", attrs={"name": "description"})
    if description_tag:
        # A description <meta> without a content attribute gets a placeholder.
        info["description"] = description_tag.get("content", "No description found.")

    html_tag = soup.find("html")
    if html_tag:
        # Same placeholder approach when <html> carries no lang attribute.
        info["language"] = html_tag.get("lang", "No language found.")

    return info
|
||||
|
||||
|
||||
def verify_ssl_cert(url: str, timeout: float = 10.0) -> bool:
    """Check that the HTTPS server behind *url* completes a verified TLS handshake.

    URLs that do not start with ``https://`` are not checked and count as
    valid.

    Args:
        url: URL whose server certificate should be checked.
        timeout: Seconds allowed for the TCP connect and TLS handshake, so an
            unresponsive host cannot block the caller indefinitely.

    Returns:
        True when no check is needed or the handshake succeeds against the
        certifi CA bundle; False on any SSL error or other failure.
    """
    if not url.startswith("https://"):
        return True

    try:
        # Parse the URL properly so "host:port" and "user@host" authorities
        # yield the bare hostname and the port the URL actually targets.
        parsed = urllib.parse.urlparse(url)
        hostname = parsed.hostname
        if not hostname:
            raise ValueError("URL has no hostname")
        port = parsed.port or 443

        context = ssl.create_default_context(cafile=certifi.where())
        # create_connection resolves both IPv4 and IPv6 addresses and applies
        # the timeout to the connect; the wrapped socket inherits it.
        with socket.create_connection((hostname, port), timeout=timeout) as sock:
            with context.wrap_socket(sock, server_hostname=hostname):
                return True
    except ssl.SSLError:
        # Certificate / handshake failures are the expected negative result.
        return False
    except Exception as e:
        log.warning(f"SSL verification failed for {url}: {str(e)}")
        return False
|
||||
|
||||
|
||||
class SafeFireCrawlLoader(BaseLoader):
|
||||
def __init__(
|
||||
self,
|
||||
web_paths,
|
||||
verify_ssl: bool = True,
|
||||
trust_env: bool = False,
|
||||
requests_per_second: Optional[float] = None,
|
||||
continue_on_failure: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
mode: Literal["crawl", "scrape", "map"] = "crawl",
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
params: Optional[Dict] = None,
|
||||
):
|
||||
"""Concurrent document loader for FireCrawl operations.
|
||||
|
||||
Executes multiple FireCrawlLoader instances concurrently using thread pooling
|
||||
to improve bulk processing efficiency.
|
||||
Args:
|
||||
web_paths: List of URLs/paths to process.
|
||||
verify_ssl: If True, verify SSL certificates.
|
||||
trust_env: If True, use proxy settings from environment variables.
|
||||
requests_per_second: Number of requests per second to limit to.
|
||||
continue_on_failure (bool): If True, continue loading other URLs on failure.
|
||||
api_key: API key for FireCrawl service. Defaults to None
|
||||
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
|
||||
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
|
||||
mode: Operation mode selection:
|
||||
- 'crawl': Website crawling mode (default)
|
||||
- 'scrape': Direct page scraping
|
||||
- 'map': Site map generation
|
||||
proxy: Proxy override settings for the FireCrawl API.
|
||||
params: The parameters to pass to the Firecrawl API.
|
||||
Examples include crawlerOptions.
|
||||
For more details, visit: https://github.com/mendableai/firecrawl-py
|
||||
"""
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
self.web_paths = web_paths
|
||||
self.verify_ssl = verify_ssl
|
||||
self.requests_per_second = requests_per_second
|
||||
self.last_request_time = None
|
||||
self.trust_env = trust_env
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.mode = mode
|
||||
self.params = params
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Load documents concurrently using FireCrawl."""
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
self._safe_process_url_sync(url)
|
||||
loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
mode=self.mode,
|
||||
params=self.params,
|
||||
)
|
||||
yield from loader.lazy_load()
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
|
||||
async def alazy_load(self):
|
||||
"""Async version of lazy_load."""
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
await self._safe_process_url(url)
|
||||
loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
mode=self.mode,
|
||||
params=self.params,
|
||||
)
|
||||
async for document in loader.alazy_load():
|
||||
yield document
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
    """Run the pre-request checks for ``url``, then wait for the rate limit.

    Returns:
        Always ``True`` when the checks pass.

    Raises:
        ValueError: If ``self.verify_ssl`` is set and ``_verify_ssl_cert``
            rejects the URL.
    """
    if self.verify_ssl:
        certificate_ok = self._verify_ssl_cert(url)
        if not certificate_ok:
            raise ValueError(f"SSL certificate verification failed for {url}")

    await self._wait_for_rate_limit()
    return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
    """Run the pre-request checks for ``url``, then block for the rate limit.

    Synchronous counterpart of ``_safe_process_url``.

    Returns:
        Always ``True`` when the checks pass.

    Raises:
        ValueError: If ``self.verify_ssl`` is set and ``_verify_ssl_cert``
            rejects the URL.
    """
    if self.verify_ssl:
        certificate_ok = self._verify_ssl_cert(url)
        if not certificate_ok:
            raise ValueError(f"SSL certificate verification failed for {url}")

    self._sync_wait_for_rate_limit()
    return True
|
||||
|
||||
|
||||
class SafePlaywrightURLLoader(PlaywrightURLLoader):
    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.

    Attributes:
        web_paths (List[str]): List of URLs to load (stored by the base class as ``urls``).
        verify_ssl (bool): If True, verify SSL certificates.
        trust_env (bool): If True, use proxy settings from environment variables.
        requests_per_second (Optional[float]): Number of requests per second to limit to.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        headless (bool): If True, the browser will run in headless mode.
        proxy (dict): Proxy override settings for the Playwright session.
        playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
    """

    def __init__(
        self,
        web_paths: List[str],
        verify_ssl: bool = True,
        trust_env: bool = False,
        requests_per_second: Optional[float] = None,
        continue_on_failure: bool = True,
        headless: bool = True,
        remove_selectors: Optional[List[str]] = None,
        proxy: Optional[Dict[str, str]] = None,
        playwright_ws_url: Optional[str] = None,
    ):
        """Initialize with additional safety parameters and remote browser support."""
        proxy_server = proxy.get("server") if proxy else None
        if trust_env and not proxy_server:
            # No explicit proxy server was given: fall back to the environment.
            env_proxies = urllib.request.getproxies()
            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
            if env_proxy_server:
                # Build a new dict rather than writing into the caller's one,
                # so the argument passed in is never mutated.
                proxy = {**(proxy or {}), "server": env_proxy_server}

        # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
        super().__init__(
            urls=web_paths,
            continue_on_failure=continue_on_failure,
            headless=headless if playwright_ws_url is None else False,
            remove_selectors=remove_selectors,
            proxy=proxy,
        )
        self.verify_ssl = verify_ssl
        self.requests_per_second = requests_per_second
        self.last_request_time = None
        self.playwright_ws_url = playwright_ws_url
        self.trust_env = trust_env

    def lazy_load(self) -> Iterator[Document]:
        """Safely load URLs synchronously with support for remote browser.

        Yields one ``Document`` per URL that loads successfully. A failing URL
        is logged and skipped when ``continue_on_failure`` is truthy;
        otherwise its exception propagates. Each page is closed once it has
        been evaluated, and the browser is closed however iteration ends.
        """
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = p.chromium.connect(self.playwright_ws_url)
            else:
                browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)

            try:
                for url in self.urls:
                    page = None
                    try:
                        self._safe_process_url_sync(url)
                        page = browser.new_page()
                        response = page.goto(url)
                        if response is None:
                            raise ValueError(f"page.goto() returned None for url {url}")

                        text = self.evaluator.evaluate(page, browser, response)
                        document = Document(
                            page_content=text, metadata={"source": url}
                        )
                    except Exception:
                        if self.continue_on_failure:
                            # Format string goes first; the active exception
                            # is attached by Logger.exception itself.
                            log.exception("Error loading %s", url)
                            continue
                        raise
                    finally:
                        # Release the page before moving on so long URL lists
                        # do not accumulate open pages.
                        if page is not None:
                            try:
                                page.close()
                            except Exception:
                                log.debug(
                                    "Error closing page for %s", url, exc_info=True
                                )
                    yield document
            finally:
                # Runs on normal completion, on a propagated error, and when
                # the consumer stops iterating early.
                browser.close()

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Safely load URLs asynchronously with support for remote browser.

        Yields one ``Document`` per URL that loads successfully. A failing URL
        is logged and skipped when ``continue_on_failure`` is truthy;
        otherwise its exception propagates. Each page is closed once it has
        been evaluated, and the browser is closed however iteration ends.
        """
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            # Use remote browser if ws_endpoint is provided, otherwise use local browser
            if self.playwright_ws_url:
                browser = await p.chromium.connect(self.playwright_ws_url)
            else:
                browser = await p.chromium.launch(
                    headless=self.headless, proxy=self.proxy
                )

            try:
                for url in self.urls:
                    page = None
                    try:
                        await self._safe_process_url(url)
                        page = await browser.new_page()
                        response = await page.goto(url)
                        if response is None:
                            raise ValueError(f"page.goto() returned None for url {url}")

                        text = await self.evaluator.evaluate_async(
                            page, browser, response
                        )
                        document = Document(
                            page_content=text, metadata={"source": url}
                        )
                    except Exception:
                        if self.continue_on_failure:
                            # Format string goes first; the active exception
                            # is attached by Logger.exception itself.
                            log.exception("Error loading %s", url)
                            continue
                        raise
                    finally:
                        # Release the page before moving on so long URL lists
                        # do not accumulate open pages.
                        if page is not None:
                            try:
                                await page.close()
                            except Exception:
                                log.debug(
                                    "Error closing page for %s", url, exc_info=True
                                )
                    yield document
            finally:
                # Runs on normal completion, on a propagated error, and when
                # the consumer stops iterating early.
                await browser.close()

    def _verify_ssl_cert(self, url: str) -> bool:
        """Return what the module-level ``verify_ssl_cert`` reports for ``url``."""
        return verify_ssl_cert(url)

    async def _wait_for_rate_limit(self):
        """Wait to respect the rate limit if specified."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            # datetime.now() is naive wall-clock time; clamp at zero so a
            # backwards clock step cannot stretch the wait past one interval.
            time_since_last = max(
                datetime.now() - self.last_request_time, timedelta(0)
            )
            if time_since_last < min_interval:
                await asyncio.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    def _sync_wait_for_rate_limit(self):
        """Synchronous version of rate limit wait."""
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            # datetime.now() is naive wall-clock time; clamp at zero so a
            # backwards clock step cannot stretch the wait past one interval.
            time_since_last = max(
                datetime.now() - self.last_request_time, timedelta(0)
            )
            if time_since_last < min_interval:
                time.sleep((min_interval - time_since_last).total_seconds())
        self.last_request_time = datetime.now()

    async def _safe_process_url(self, url: str) -> bool:
        """Perform safety checks before processing a URL.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        await self._wait_for_rate_limit()
        return True

    def _safe_process_url_sync(self, url: str) -> bool:
        """Synchronous version of safety checks.

        Raises:
            ValueError: If SSL verification is enabled and fails for ``url``.
        """
        if self.verify_ssl and not self._verify_ssl_cert(url):
            raise ValueError(f"SSL certificate verification failed for {url}")
        self._sync_wait_for_rate_limit()
        return True
|
||||
|
||||
|
||||
class SafeWebBaseLoader(WebBaseLoader):
|
||||
"""WebBaseLoader with enhanced error handling for URLs."""
|
||||
|
||||
@@ -143,20 +467,12 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
|
||||
# Build metadata
|
||||
metadata = {"source": path}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get(
|
||||
"content", "No description found."
|
||||
)
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
metadata = extract_metadata(soup, path)
|
||||
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
# Log the error and continue with the next URL
|
||||
log.error(f"Error loading {path}: {e}")
|
||||
log.exception(e, "Error loading %s", path)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async lazy load text from the url(s) in web_path."""
|
||||
@@ -179,6 +495,12 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
|
||||
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
|
||||
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
|
||||
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
|
||||
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
|
||||
|
||||
|
||||
def get_web_loader(
|
||||
urls: Union[str, Sequence[str]],
|
||||
verify_ssl: bool = True,
|
||||
@@ -188,10 +510,29 @@ def get_web_loader(
|
||||
# Check if the URLs are valid
|
||||
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
||||
|
||||
return SafeWebBaseLoader(
|
||||
web_path=safe_urls,
|
||||
verify_ssl=verify_ssl,
|
||||
requests_per_second=requests_per_second,
|
||||
continue_on_failure=True,
|
||||
trust_env=trust_env,
|
||||
web_loader_args = {
|
||||
"web_paths": safe_urls,
|
||||
"verify_ssl": verify_ssl,
|
||||
"requests_per_second": requests_per_second,
|
||||
"continue_on_failure": True,
|
||||
"trust_env": trust_env,
|
||||
}
|
||||
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
|
||||
|
||||
# Create the appropriate WebLoader based on the configuration
|
||||
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
|
||||
log.debug(
|
||||
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
|
||||
web_loader.__class__.__name__,
|
||||
len(safe_urls),
|
||||
)
|
||||
|
||||
return web_loader
|
||||
|
||||
@@ -37,6 +37,7 @@ from open_webui.config import (
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
DEVICE_TYPE,
|
||||
@@ -70,7 +71,7 @@ from pydub.utils import mediainfo
|
||||
def is_mp4_audio(file_path):
|
||||
"""Check if the given file is an MP4 audio file."""
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
log.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
info = mediainfo(file_path)
|
||||
@@ -87,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
|
||||
"""Convert MP4 audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format="mp4")
|
||||
audio.export(output_path, format="wav")
|
||||
print(f"Converted {file_path} to {output_path}")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
@@ -265,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
json=payload,
|
||||
@@ -323,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
|
||||
json={
|
||||
@@ -380,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
||||
<voice name="{language}">{payload["input"]}</voice>
|
||||
</speak>"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
|
||||
headers={
|
||||
@@ -458,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
print("transcribe", file_path)
|
||||
log.info(f"transcribe: {file_path}")
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
@@ -670,7 +679,22 @@ def transcription(
|
||||
def get_available_models(request: Request) -> list[dict]:
|
||||
available_models = []
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
available_models = data.get("models", [])
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from custom endpoint: {str(e)}")
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
else:
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
response = requests.get(
|
||||
@@ -701,14 +725,37 @@ def get_available_voices(request) -> dict:
|
||||
"""Returns {voice_id: voice_name} dict"""
|
||||
available_voices = {}
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
voices_list = data.get("voices", [])
|
||||
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
else:
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
available_voices = get_elevenlabs_voices(
|
||||
|
||||
@@ -31,10 +31,7 @@ from open_webui.env import (
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from open_webui.config import (
|
||||
OPENID_PROVIDER_URL,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
)
|
||||
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||
from pydantic import BaseModel
|
||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||
from open_webui.utils.auth import (
|
||||
@@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
|
||||
from typing import Optional, List
|
||||
|
||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
if ENABLE_LDAP.value:
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -252,14 +251,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
if (
|
||||
request.app.state.USER_COUNT
|
||||
and user_count >= request.app.state.USER_COUNT
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
@@ -423,7 +414,6 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
@@ -434,16 +424,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
else:
|
||||
if user_count != 0:
|
||||
if Users.get_num_users() != 0:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
@@ -546,7 +532,8 @@ async def signout(request: Request, response: Response):
|
||||
if logout_url:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
return RedirectResponse(
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}"
|
||||
headers=response.headers,
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -612,7 +599,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
||||
print(admin_email, admin_name)
|
||||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
|
||||
@@ -70,6 +70,12 @@ async def set_direct_connections_config(
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
class CodeInterpreterConfigForm(BaseModel):
|
||||
CODE_EXECUTION_ENGINE: str
|
||||
CODE_EXECUTION_JUPYTER_URL: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
|
||||
ENABLE_CODE_INTERPRETER: bool
|
||||
CODE_INTERPRETER_ENGINE: str
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
|
||||
@@ -77,11 +83,18 @@ class CodeInterpreterConfigForm(BaseModel):
|
||||
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
|
||||
|
||||
|
||||
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
|
||||
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
@@ -89,13 +102,32 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_interpreter_config(
|
||||
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_execution_config(
|
||||
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_URL
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
|
||||
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
|
||||
@@ -116,8 +148,17 @@ async def set_code_interpreter_config(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
@@ -125,6 +166,7 @@ async def set_code_interpreter_config(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from open_webui.models.files import (
|
||||
Files,
|
||||
)
|
||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||
from open_webui.routers.audio import transcribe
|
||||
from open_webui.storage.provider import Storage
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from pydantic import BaseModel
|
||||
@@ -67,7 +68,22 @@ def upload_file(
|
||||
)
|
||||
|
||||
try:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
if file.content_type in [
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/x-m4a",
|
||||
]:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
@@ -225,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename) # RFC5987 encoding
|
||||
|
||||
content_type = file.meta.get("content_type")
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename)
|
||||
headers = {}
|
||||
if file.meta.get("content_type") not in [
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
]:
|
||||
headers = {
|
||||
**headers,
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
|
||||
return FileResponse(file_path, headers=headers)
|
||||
if content_type == "application/pdf" or filename.lower().endswith(
|
||||
".pdf"
|
||||
):
|
||||
headers["Content-Disposition"] = (
|
||||
f"inline; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
content_type = "application/pdf"
|
||||
elif content_type != "text/plain":
|
||||
headers["Content-Disposition"] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
|
||||
return FileResponse(file_path, headers=headers, media_type=content_type)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -266,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
log.info(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -79,7 +85,7 @@ async def create_new_function(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to create a new function: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -183,7 +189,7 @@ async def update_function_by_id(
|
||||
FUNCTIONS[id] = function_module
|
||||
|
||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
|
||||
@@ -299,7 +305,7 @@ async def update_function_valves_by_id(
|
||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
16
backend/open_webui/routers/groups.py
Normal file → Executable file
16
backend/open_webui/routers/groups.py
Normal file → Executable file
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import (
|
||||
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new group: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -94,7 +100,7 @@ async def update_group_by_id(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error deleting group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
||||
@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
|
||||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
|
||||
class GeminiConfigForm(BaseModel):
|
||||
GEMINI_API_BASE_URL: str
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
|
||||
class ConfigForm(BaseModel):
|
||||
enabled: bool
|
||||
engine: str
|
||||
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
|
||||
openai: OpenAIConfigForm
|
||||
automatic1111: Automatic1111ConfigForm
|
||||
comfyui: ComfyUIConfigForm
|
||||
gemini: GeminiConfigForm
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
@@ -103,6 +113,11 @@ async def update_config(
|
||||
)
|
||||
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
||||
|
||||
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
||||
form_data.gemini.GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
|
||||
|
||||
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
||||
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
||||
)
|
||||
@@ -129,6 +144,8 @@ async def update_config(
|
||||
request.app.state.config.COMFYUI_BASE_URL = (
|
||||
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
||||
)
|
||||
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
|
||||
|
||||
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
||||
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
|
||||
form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
||||
@@ -155,6 +172,10 @@ async def update_config(
|
||||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
|
||||
headers = None
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||
headers=headers,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
@@ -224,6 +253,12 @@ def get_image_model(request):
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "dall-e-2"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "imagen-3.0-generate-002"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
@@ -299,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
headers = {
|
||||
@@ -322,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
if model_node_id:
|
||||
model_list_key = None
|
||||
|
||||
print(workflow[model_node_id]["class_type"])
|
||||
log.info(workflow[model_node_id]["class_type"])
|
||||
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
||||
"required"
|
||||
]:
|
||||
@@ -483,6 +522,41 @@ async def image_generations(
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
|
||||
|
||||
model = get_image_model(request)
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = load_b64_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
|
||||
@@ -614,7 +614,7 @@ def add_files_to_knowledge_batch(
|
||||
)
|
||||
|
||||
# Get files content
|
||||
print(f"files/batch/add - {len(form_data)} files")
|
||||
log.info(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
|
||||
@@ -14,6 +14,11 @@ from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
import requests
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.env import (
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -26,7 +31,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
||||
@@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -96,6 +115,7 @@ async def send_post_request(
|
||||
stream: bool = True,
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
):
|
||||
|
||||
r = None
|
||||
@@ -110,6 +130,16 @@ async def send_post_request(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -191,7 +221,19 @@ async def verify_connection(
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/api/version",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
detail = f"HTTP Error: {r.status}"
|
||||
@@ -254,7 +296,7 @@ async def update_config(
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request):
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
@@ -262,7 +304,7 @@ async def get_all_models(request: Request):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags"))
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
@@ -275,7 +317,9 @@ async def get_all_models(request: Request):
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", key))
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/tags", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
@@ -360,7 +404,7 @@ async def get_ollama_tags(
|
||||
models = []
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
@@ -370,7 +414,19 @@ async def get_ollama_tags(
|
||||
r = requests.request(
|
||||
method="GET",
|
||||
url=f"{url}/api/tags",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
user=user,
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
@@ -509,6 +566,7 @@ async def pull_model(
|
||||
url=f"{url}/api/pull",
|
||||
payload=json.dumps(payload),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -527,7 +585,7 @@ async def push_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -545,6 +603,7 @@ async def push_model(
|
||||
url=f"{url}/api/push",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -571,6 +630,7 @@ async def create_model(
|
||||
url=f"{url}/api/create",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -588,7 +648,7 @@ async def copy_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.source in models:
|
||||
@@ -609,6 +669,16 @@ async def copy_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -643,7 +713,7 @@ async def delete_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
@@ -665,6 +735,16 @@ async def delete_model(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
@@ -693,7 +773,7 @@ async def delete_model(
|
||||
async def show_model_info(
|
||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name not in models:
|
||||
@@ -714,6 +794,16 @@ async def show_model_info(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -757,7 +847,7 @@ async def embed(
|
||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -783,6 +873,16 @@ async def embed(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -826,7 +926,7 @@ async def embeddings(
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -852,6 +952,16 @@ async def embeddings(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
@@ -901,7 +1011,7 @@ async def generate_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
@@ -931,15 +1041,29 @@ async def generate_completion(
|
||||
url=f"{url}/api/generate",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[dict]] = None
|
||||
images: Optional[list[str]] = None
|
||||
|
||||
@validator("content", pre=True)
|
||||
@classmethod
|
||||
def check_at_least_one_field(cls, field_value, values, **kwargs):
|
||||
# Raise an error if both 'content' and 'tool_calls' are None
|
||||
if field_value is None and (
|
||||
"tool_calls" not in values or values["tool_calls"] is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of 'content' or 'tool_calls' must be provided"
|
||||
)
|
||||
|
||||
return field_value
|
||||
|
||||
|
||||
class GenerateChatCompletionForm(BaseModel):
|
||||
model: str
|
||||
@@ -1047,6 +1171,7 @@ async def generate_chat_completion(
|
||||
stream=form_data.stream,
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1149,6 +1274,7 @@ async def generate_openai_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1227,6 +1353,7 @@ async def generate_openai_chat_completion(
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1240,7 +1367,7 @@ async def get_openai_models(
|
||||
|
||||
models = []
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models(request)
|
||||
model_list = await get_all_models(request, user=user)
|
||||
models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
|
||||
@@ -26,6 +26,7 @@ from open_webui.env import (
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
@@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -84,9 +98,15 @@ def openai_o1_o3_handler(payload):
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
|
||||
# Fix: o1 and o3 do not support the "system" role directly.
|
||||
# For older models like "o1-mini" or "o1-preview", use role "user".
|
||||
# For newer o1/o3 models, replace "system" with "developer".
|
||||
if payload["messages"][0]["role"] == "system":
|
||||
payload["messages"][0]["role"] = "user"
|
||||
model_lower = payload["model"].lower()
|
||||
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
|
||||
payload["messages"][0]["role"] = "user"
|
||||
else:
|
||||
payload["messages"][0]["role"] = "developer"
|
||||
|
||||
return payload
|
||||
|
||||
@@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||
|
||||
|
||||
async def get_all_models_responses(request: Request) -> list:
|
||||
async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
|
||||
@@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
):
|
||||
request_tasks.append(
|
||||
send_get_request(
|
||||
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
|
||||
send_get_request(
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -352,13 +375,13 @@ async def get_filtered_models(models, user):
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request) -> dict[str, list]:
|
||||
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||
log.info("get_all_models()")
|
||||
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user=user)
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
@@ -418,7 +441,7 @@ async def get_models(
|
||||
}
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
@@ -515,6 +538,16 @@ async def verify_connection(
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
@@ -587,7 +620,7 @@ async def generate_chat_completion(
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
@@ -777,7 +810,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
|
||||
@@ -101,7 +101,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
if "detail" in res:
|
||||
raise Exception(response.status, res["detail"])
|
||||
except Exception as e:
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
@@ -153,7 +153,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
@@ -169,7 +169,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("/list")
|
||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user)
|
||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||
|
||||
urlIdxs = [
|
||||
@@ -196,7 +196,7 @@ async def upload_pipeline(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
print("upload_pipeline", urlIdx, file.filename)
|
||||
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
|
||||
# Check if the uploaded file is a python file
|
||||
if not (file.filename and file.filename.endswith(".py")):
|
||||
raise HTTPException(
|
||||
@@ -231,7 +231,7 @@ async def upload_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
@@ -282,7 +282,7 @@ async def add_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -327,7 +327,7 @@ async def delete_pipeline(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -361,7 +361,7 @@ async def get_pipelines(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -400,7 +400,7 @@ async def get_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -440,7 +440,7 @@ async def get_pipeline_valves_spec(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
@@ -482,7 +482,7 @@ async def update_pipeline_valves(
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
|
||||
|
||||
@@ -351,10 +351,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -371,10 +378,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
@@ -397,6 +406,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
@@ -409,9 +419,15 @@ class FileConfig(BaseModel):
|
||||
max_count: Optional[int] = None
|
||||
|
||||
|
||||
class DocumentIntelligenceConfigForm(BaseModel):
|
||||
endpoint: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContentExtractionConfig(BaseModel):
|
||||
engine: str = ""
|
||||
tika_server_url: Optional[str] = None
|
||||
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
|
||||
|
||||
|
||||
class ChunkParamUpdateForm(BaseModel):
|
||||
@@ -457,12 +473,16 @@ class WebSearchConfig(BaseModel):
|
||||
|
||||
class WebConfig(BaseModel):
|
||||
search: WebSearchConfig
|
||||
web_loader_ssl_verification: Optional[bool] = None
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
|
||||
|
||||
class ConfigUpdateForm(BaseModel):
|
||||
RAG_FULL_CONTEXT: Optional[bool] = None
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
pdf_extract_images: Optional[bool] = None
|
||||
enable_google_drive_integration: Optional[bool] = None
|
||||
enable_onedrive_integration: Optional[bool] = None
|
||||
file: Optional[FileConfig] = None
|
||||
content_extraction: Optional[ContentExtractionConfig] = None
|
||||
chunk: Optional[ChunkParamUpdateForm] = None
|
||||
@@ -480,24 +500,51 @@ async def update_rag_config(
|
||||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_FULL_CONTEXT = (
|
||||
form_data.RAG_FULL_CONTEXT
|
||||
if form_data.RAG_FULL_CONTEXT is not None
|
||||
else request.app.state.config.RAG_FULL_CONTEXT
|
||||
)
|
||||
|
||||
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
|
||||
else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
|
||||
form_data.enable_google_drive_integration
|
||||
if form_data.enable_google_drive_integration is not None
|
||||
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
|
||||
form_data.enable_onedrive_integration
|
||||
if form_data.enable_onedrive_integration is not None
|
||||
else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
if form_data.file is not None:
|
||||
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
|
||||
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
||||
|
||||
if form_data.content_extraction is not None:
|
||||
log.info(f"Updating text settings: {form_data.content_extraction}")
|
||||
log.info(
|
||||
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
|
||||
)
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
|
||||
form_data.content_extraction.engine
|
||||
)
|
||||
request.app.state.config.TIKA_SERVER_URL = (
|
||||
form_data.content_extraction.tika_server_url
|
||||
)
|
||||
if form_data.content_extraction.document_intelligence_config is not None:
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
|
||||
form_data.content_extraction.document_intelligence_config.endpoint
|
||||
)
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
|
||||
form_data.content_extraction.document_intelligence_config.key
|
||||
)
|
||||
|
||||
if form_data.chunk is not None:
|
||||
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
||||
@@ -512,11 +559,16 @@ async def update_rag_config(
|
||||
if form_data.web is not None:
|
||||
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
|
||||
form_data.web.web_loader_ssl_verification
|
||||
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
|
||||
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
|
||||
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.SEARXNG_QUERY_URL = (
|
||||
form_data.web.search.searxng_query_url
|
||||
)
|
||||
@@ -581,6 +633,8 @@ async def update_rag_config(
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"file": {
|
||||
"max_size": request.app.state.config.FILE_MAX_SIZE,
|
||||
"max_count": request.app.state.config.FILE_MAX_COUNT,
|
||||
@@ -588,6 +642,10 @@ async def update_rag_config(
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -600,7 +658,8 @@ async def update_rag_config(
|
||||
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
@@ -863,7 +922,12 @@ def process_file(
|
||||
# Update the content in the file
|
||||
# Usage: /files/{file_id}/data/content/update
|
||||
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
try:
|
||||
# /files/{file_id}/data/content/update
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
except:
|
||||
# Audio file upload pipeline
|
||||
pass
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
@@ -920,6 +984,8 @@ def process_file(
|
||||
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
)
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
@@ -962,36 +1028,45 @@ def process_file(
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
user=user,
|
||||
)
|
||||
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
@@ -1262,6 +1337,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
request.app.state.config.TAVILY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No TAVILY_API_KEY found in environment variables")
|
||||
@@ -1349,21 +1425,37 @@ async def process_web_search(
|
||||
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
await run_in_threadpool(
|
||||
save_docs_to_vector_db,
|
||||
request,
|
||||
docs,
|
||||
collection_name,
|
||||
overwrite=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filenames": urls,
|
||||
"docs": [
|
||||
{
|
||||
"content": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
for doc in docs
|
||||
],
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
else:
|
||||
await run_in_threadpool(
|
||||
save_docs_to_vector_db,
|
||||
request,
|
||||
docs,
|
||||
collection_name,
|
||||
overwrite=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
@@ -1520,11 +1612,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"The directory {folder} does not exist")
|
||||
log.warning(f"The directory {folder} does not exist")
|
||||
except Exception as e:
|
||||
print(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -208,7 +212,7 @@ async def generate_title(
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 1000}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 1000,
|
||||
}
|
||||
@@ -221,6 +225,12 @@ async def generate_title(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -290,6 +300,12 @@ async def generate_chat_tags(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -356,6 +372,12 @@ async def generate_image_prompt(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -433,6 +455,12 @@ async def generate_queries(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -514,6 +542,12 @@ async def generate_autocompletion(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -571,7 +605,7 @@ async def generate_emoji(
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
@@ -584,6 +618,12 @@ async def generate_emoji(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -644,6 +684,12 @@ async def generate_moa_response(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -111,7 +116,7 @@ async def create_new_tools(
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -193,7 +198,7 @@ async def update_tools_by_id(
|
||||
"specs": specs,
|
||||
}
|
||||
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
|
||||
if tools:
|
||||
@@ -343,7 +348,7 @@ async def update_tools_valves_by_id(
|
||||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
@@ -421,7 +426,7 @@ async def update_tools_user_valves_by_id(
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
||||
@@ -1,48 +1,84 @@
|
||||
import black
|
||||
import logging
|
||||
import markdown
|
||||
|
||||
from open_webui.models.chats import ChatTitleMessagesForm
|
||||
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
from open_webui.utils.misc import get_gravatar_url
|
||||
from open_webui.utils.pdf_generator import PDFGenerator
|
||||
from open_webui.utils.auth import get_admin_user
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
):
|
||||
async def get_gravatar(email: str, user=Depends(get_verified_user)):
|
||||
return get_gravatar_url(email)
|
||||
|
||||
|
||||
class CodeFormatRequest(BaseModel):
|
||||
class CodeForm(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
@router.post("/code/format")
|
||||
async def format_code(request: CodeFormatRequest):
|
||||
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
formatted_code = black.format_str(request.code, mode=black.Mode())
|
||||
formatted_code = black.format_str(form_data.code, mode=black.Mode())
|
||||
return {"code": formatted_code}
|
||||
except black.NothingChanged:
|
||||
return {"code": request.code}
|
||||
return {"code": form_data.code}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/code/execute")
|
||||
async def execute_code(
|
||||
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
|
||||
):
|
||||
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
form_data.code,
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
|
||||
else None
|
||||
),
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
)
|
||||
|
||||
return output
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Code execution engine not supported",
|
||||
)
|
||||
|
||||
|
||||
class MarkdownForm(BaseModel):
|
||||
md: str
|
||||
|
||||
|
||||
@router.post("/markdown")
|
||||
async def get_html_from_markdown(
|
||||
form_data: MarkdownForm,
|
||||
form_data: MarkdownForm, user=Depends(get_verified_user)
|
||||
):
|
||||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
|
||||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
|
||||
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating PDF: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO, Tuple
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
from open_webui.config import (
|
||||
S3_ACCESS_KEY_ID,
|
||||
@@ -13,14 +15,27 @@ from open_webui.config import (
|
||||
S3_KEY_PREFIX,
|
||||
S3_REGION_NAME,
|
||||
S3_SECRET_ACCESS_KEY,
|
||||
S3_USE_ACCELERATE_ENDPOINT,
|
||||
S3_ADDRESSING_STYLE,
|
||||
GCS_BUCKET_NAME,
|
||||
GOOGLE_APPLICATION_CREDENTIALS_JSON,
|
||||
AZURE_STORAGE_ENDPOINT,
|
||||
AZURE_STORAGE_CONTAINER_NAME,
|
||||
AZURE_STORAGE_KEY,
|
||||
STORAGE_PROVIDER,
|
||||
UPLOAD_DIR,
|
||||
)
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import GoogleCloudError, NotFound
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
class StorageProvider(ABC):
|
||||
@@ -65,7 +80,7 @@ class LocalStorageProvider(StorageProvider):
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
else:
|
||||
print(f"File {file_path} not found in local storage.")
|
||||
log.warning(f"File {file_path} not found in local storage.")
|
||||
|
||||
@staticmethod
|
||||
def delete_all_files() -> None:
|
||||
@@ -79,9 +94,9 @@ class LocalStorageProvider(StorageProvider):
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"Directory {UPLOAD_DIR} not found in local storage.")
|
||||
log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")
|
||||
|
||||
|
||||
class S3StorageProvider(StorageProvider):
|
||||
@@ -92,6 +107,12 @@ class S3StorageProvider(StorageProvider):
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
config=Config(
|
||||
s3={
|
||||
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
|
||||
"addressing_style": S3_ADDRESSING_STYLE,
|
||||
},
|
||||
),
|
||||
)
|
||||
self.bucket_name = S3_BUCKET_NAME
|
||||
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
|
||||
@@ -221,6 +242,74 @@ class GCSStorageProvider(StorageProvider):
|
||||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
|
||||
class AzureStorageProvider(StorageProvider):
|
||||
def __init__(self):
|
||||
self.endpoint = AZURE_STORAGE_ENDPOINT
|
||||
self.container_name = AZURE_STORAGE_CONTAINER_NAME
|
||||
storage_key = AZURE_STORAGE_KEY
|
||||
|
||||
if storage_key:
|
||||
# Configure using the Azure Storage Account Endpoint and Key
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=self.endpoint, credential=storage_key
|
||||
)
|
||||
else:
|
||||
# Configure using the Azure Storage Account Endpoint and DefaultAzureCredential
|
||||
# If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=self.endpoint, credential=DefaultAzureCredential()
|
||||
)
|
||||
self.container_client = self.blob_service_client.get_container_client(
|
||||
self.container_name
|
||||
)
|
||||
|
||||
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
||||
"""Handles uploading of the file to Azure Blob Storage."""
|
||||
contents, file_path = LocalStorageProvider.upload_file(file, filename)
|
||||
try:
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
blob_client.upload_blob(contents, overwrite=True)
|
||||
return contents, f"{self.endpoint}/{self.container_name}/{filename}"
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")
|
||||
|
||||
def get_file(self, file_path: str) -> str:
|
||||
"""Handles downloading of the file from Azure Blob Storage."""
|
||||
try:
|
||||
filename = file_path.split("/")[-1]
|
||||
local_file_path = f"{UPLOAD_DIR}/{filename}"
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
with open(local_file_path, "wb") as download_file:
|
||||
download_file.write(blob_client.download_blob().readall())
|
||||
return local_file_path
|
||||
except ResourceNotFoundError as e:
|
||||
raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")
|
||||
|
||||
def delete_file(self, file_path: str) -> None:
|
||||
"""Handles deletion of the file from Azure Blob Storage."""
|
||||
try:
|
||||
filename = file_path.split("/")[-1]
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
blob_client.delete_blob()
|
||||
except ResourceNotFoundError as e:
|
||||
raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")
|
||||
|
||||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_file(file_path)
|
||||
|
||||
def delete_all_files(self) -> None:
|
||||
"""Handles deletion of all files from Azure Blob Storage."""
|
||||
try:
|
||||
blobs = self.container_client.list_blobs()
|
||||
for blob in blobs:
|
||||
self.container_client.delete_blob(blob.name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}")
|
||||
|
||||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
|
||||
def get_storage_provider(storage_provider: str):
|
||||
if storage_provider == "local":
|
||||
Storage = LocalStorageProvider()
|
||||
@@ -228,6 +317,8 @@ def get_storage_provider(storage_provider: str):
|
||||
Storage = S3StorageProvider()
|
||||
elif storage_provider == "gcs":
|
||||
Storage = GCSStorageProvider()
|
||||
elif storage_provider == "azure":
|
||||
Storage = AzureStorageProvider()
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
|
||||
return Storage
|
||||
|
||||
@@ -7,6 +7,8 @@ from moto import mock_aws
|
||||
from open_webui.storage import provider
|
||||
from gcp_storage_emulator.server import create_server
|
||||
from google.cloud import storage
|
||||
from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def mock_upload_dir(monkeypatch, tmp_path):
|
||||
@@ -22,6 +24,7 @@ def test_imports():
|
||||
provider.LocalStorageProvider
|
||||
provider.S3StorageProvider
|
||||
provider.GCSStorageProvider
|
||||
provider.AzureStorageProvider
|
||||
provider.Storage
|
||||
|
||||
|
||||
@@ -32,6 +35,8 @@ def test_get_storage_provider():
|
||||
assert isinstance(Storage, provider.S3StorageProvider)
|
||||
Storage = provider.get_storage_provider("gcs")
|
||||
assert isinstance(Storage, provider.GCSStorageProvider)
|
||||
Storage = provider.get_storage_provider("azure")
|
||||
assert isinstance(Storage, provider.AzureStorageProvider)
|
||||
with pytest.raises(RuntimeError):
|
||||
provider.get_storage_provider("invalid")
|
||||
|
||||
@@ -48,6 +53,7 @@ def test_class_instantiation():
|
||||
provider.LocalStorageProvider()
|
||||
provider.S3StorageProvider()
|
||||
provider.GCSStorageProvider()
|
||||
provider.AzureStorageProvider()
|
||||
|
||||
|
||||
class TestLocalStorageProvider:
|
||||
@@ -272,3 +278,147 @@ class TestGCSStorageProvider:
|
||||
assert not (upload_dir / self.filename_extra).exists()
|
||||
assert self.Storage.bucket.get_blob(self.filename) == None
|
||||
assert self.Storage.bucket.get_blob(self.filename_extra) == None
|
||||
|
||||
|
||||
class TestAzureStorageProvider:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup_storage(self, monkeypatch):
|
||||
# Create mock Blob Service Client and related clients
|
||||
mock_blob_service_client = MagicMock()
|
||||
mock_container_client = MagicMock()
|
||||
mock_blob_client = MagicMock()
|
||||
|
||||
# Set up return values for the mock
|
||||
mock_blob_service_client.get_container_client.return_value = (
|
||||
mock_container_client
|
||||
)
|
||||
mock_container_client.get_blob_client.return_value = mock_blob_client
|
||||
|
||||
# Monkeypatch the Azure classes to return our mocks
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"BlobServiceClient",
|
||||
lambda *args, **kwargs: mock_blob_service_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"ContainerClient",
|
||||
lambda *args, **kwargs: mock_container_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
|
||||
)
|
||||
|
||||
self.Storage = provider.AzureStorageProvider()
|
||||
self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
|
||||
self.Storage.container_name = "my-container"
|
||||
self.file_content = b"test content"
|
||||
self.filename = "test.txt"
|
||||
self.filename_extra = "test_extra.txt"
|
||||
self.file_bytesio_empty = io.BytesIO()
|
||||
|
||||
# Apply mocks to the Storage instance
|
||||
self.Storage.blob_service_client = mock_blob_service_client
|
||||
self.Storage.container_client = mock_container_client
|
||||
|
||||
def test_upload_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
|
||||
# Simulate an error when container does not exist
|
||||
self.Storage.container_client.get_blob_client.side_effect = Exception(
|
||||
"Container does not exist"
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
|
||||
# Reset side effect and create container
|
||||
self.Storage.container_client.get_blob_client.side_effect = None
|
||||
self.Storage.create_container()
|
||||
contents, azure_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
|
||||
# Assertions
|
||||
self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
|
||||
self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
|
||||
self.file_content, overwrite=True
|
||||
)
|
||||
assert contents == self.file_content
|
||||
assert (
|
||||
azure_file_path
|
||||
== f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
|
||||
|
||||
def test_get_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.Storage.create_container()
|
||||
|
||||
# Mock upload behavior
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# Mock blob download behavior
|
||||
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
|
||||
self.file_content
|
||||
)
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
file_path = self.Storage.get_file(file_url)
|
||||
|
||||
assert file_path == str(upload_dir / self.filename)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
|
||||
def test_delete_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.Storage.create_container()
|
||||
|
||||
# Mock file upload
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# Mock deletion
|
||||
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
self.Storage.delete_file(file_url)
|
||||
|
||||
self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
|
||||
def test_delete_all_files(self, monkeypatch, tmp_path):
    """Bulk deletion should list the blobs, delete them, and clear the local copies."""
    upload_dir = mock_upload_dir(monkeypatch, tmp_path)
    self.Storage.create_container()

    # Seed storage with two files through the regular upload path.
    for name in (self.filename, self.filename_extra):
        self.Storage.upload_file(io.BytesIO(self.file_content), name)

    # The mocked container reports both blobs; the mocked deletion succeeds silently.
    container_client = self.Storage.container_client
    container_client.list_blobs.return_value = [
        {"name": self.filename},
        {"name": self.filename_extra},
    ]
    container_client.get_blob_client().delete_blob.return_value = None

    self.Storage.delete_all_files()

    self.Storage.container_client.list_blobs.assert_called_once()
    self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
    assert not (upload_dir / self.filename).exists()
    assert not (upload_dir / self.filename_extra).exists()
|
||||
|
||||
def test_get_file_not_found(self, monkeypatch):
    """A failing blob download should propagate out of ``get_file``."""
    self.Storage.create_container()

    file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"

    # Make the mocked download raise, as it would for a missing blob.
    blob_client = self.Storage.container_client.get_blob_client()
    blob_client.download_blob.side_effect = Exception("Blob not found")

    with pytest.raises(Exception, match="Blob not found"):
        self.Storage.get_file(file_url)
|
||||
|
||||
249
backend/open_webui/utils/audit.py
Normal file
249
backend/open_webui/utils/audit.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from asgiref.typing import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
Scope as ASGIScope,
|
||||
)
|
||||
from loguru import logger
|
||||
from starlette.requests import Request
|
||||
|
||||
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
|
||||
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Logger
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AuditLogEntry:
    """
    Immutable record describing one audited HTTP request.

    The fields are grouped by the audit level at which they are expected to be populated; fields of a higher level keep their defaults when they are not supplied.
    """

    # `Metadata` audit level properties
    id: str  # unique identifier of this audit entry
    user: dict[str, Any]  # details of the user who made the request
    audit_level: str  # audit level in effect when the entry was produced
    verb: str  # HTTP method of the request
    request_uri: str  # URL of the request
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None
    # `Request` audit level properties
    request_object: Any = None  # captured request body
    # `Request Response` level
    response_object: Any = None  # captured response body
    response_status_code: Optional[int] = None
|
||||
|
||||
|
||||
class AuditLevel(str, Enum):
    """
    Verbosity levels for audit logging.

    Members are also plain strings (the class mixes in ``str``), so they compare equal to their textual value.
    """

    NONE = "NONE"
    METADATA = "METADATA"
    REQUEST = "REQUEST"
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
|
||||
|
||||
|
||||
class AuditLogger:
    """
    Thin wrapper around a Loguru logger dedicated to audit entries. Every record it emits carries an ``auditable=True`` binding so that sinks can tell audit records apart from ordinary log output.

    Parameters:
    logger (Logger): An instance of Loguru’s logger.
    """

    def __init__(self, logger: "Logger"):
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """
        Emit a single audit entry.

        The entry is flattened into keyword arguments of the log call and the message itself is left empty. A truthy ``extra`` mapping is attached under the ``extra`` key.
        """
        fields = asdict(audit_entry)
        if extra:
            fields.update(extra=extra)

        self.logger.log(log_level, "", **fields)
|
||||
|
||||
|
||||
class AuditContext:
    """
    Per-request accumulator for the HTTP request and response bodies. Each buffer stops growing once it holds ``max_body_size`` bytes, which bounds the memory used per audited request.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture.
    metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.request_body = bytearray()
        self.response_body = bytearray()
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        # Copy only as many bytes as still fit under the cap; drop the rest.
        room = self.max_body_size - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Append a request body chunk, truncated to the remaining capacity."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Append a response body chunk, truncated to the remaining capacity."""
        self._append_capped(self.response_body, chunk)
|
||||
|
||||
|
||||
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.

    Parameters:
    app (ASGI3Application): The wrapped ASGI application.
    excluded_paths (Optional[list[str]]): Resource names that must not be audited. They are joined, unescaped, into a regular-expression alternation matched directly after ``/api/`` or ``/api/v1/``.
    max_body_size (int): Maximum number of request/response body bytes retained per request.
    audit_level (AuditLevel): Decides whether request and response bodies are captured.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        # Only HTTP traffic is audited; any other scope type passes straight through.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Response status and body are only recorded at the most verbose level.
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)

                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()

                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        async context manager that ensures that an audit log entry is recorded after the request is processed.
        """
        # Honour the limit this middleware was configured with rather than the
        # AuditContext default.
        context = AuditContext(self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """
        Resolve the user behind the request's ``Authorization`` header.

        Raises:
        ValueError: If the request carries no ``Authorization`` header.
        """
        auth_header = request.headers.get("Authorization")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not auth_header:
            raise ValueError("Missing Authorization header")

        user = get_current_user(request, None, get_http_authorization_cred(auth_header))

        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        """
        Return True when the request must not be audited: its method is not in ``AUDITED_METHODS``, audit logging is disabled, no ``authorization`` header is present, or its path matches one of the excluded resources.
        """
        if (
            request.method not in self.AUDITED_METHODS
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        # With no exclusions the alternation below would be empty, and an empty
        # group followed by `\b` matches every /api/<resource> path, which would
        # silently disable auditing for the whole API.
        if not self.excluded_paths:
            return False

        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
        pattern = re.compile(
            r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
        )
        return bool(pattern.match(request.url.path))

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Store the body of an ``http.request`` event in the audit context."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the status code from the response start event and the body from response body events."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Build the audit entry for the finished request and write it. Any failure is logged and swallowed so that auditing never breaks request handling.
        """
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")
|
||||
@@ -5,6 +5,7 @@ import base64
|
||||
import hmac
|
||||
import hashlib
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
@@ -13,15 +14,22 @@ from typing import Optional, Union, List, Dict
|
||||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import override_static
|
||||
from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
|
||||
from open_webui.env import (
|
||||
WEBUI_SECRET_KEY,
|
||||
TRUSTED_SIGNATURE_KEY,
|
||||
STATIC_DIR,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
|
||||
|
||||
SESSION_SECRET = WEBUI_SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
@@ -47,6 +55,19 @@ def verify_signature(payload: str, signature: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def override_static(path: str, content: str):
    """
    Write Base64-encoded ``content`` to the file ``path`` inside ``STATIC_DIR``.

    ``path`` must be a bare file name: an empty value, or one containing ``/`` or ``..``, is logged as invalid and ignored.
    """
    # Ensure path is safe
    if not path or "/" in path or ".." in path:
        log.error(f"Invalid path: {path}")
        return

    # Decode first: opening the target with "wb" truncates it, so malformed
    # Base64 must fail before an existing static file is touched.
    data = base64.b64decode(content)  # Convert Base64 back to raw binary

    file_path = os.path.join(STATIC_DIR, path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        f.write(data)
|
||||
|
||||
|
||||
def get_license_data(app, key):
|
||||
if key:
|
||||
try:
|
||||
@@ -69,11 +90,11 @@ def get_license_data(app, key):
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
log.error(
|
||||
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
|
||||
)
|
||||
except Exception as ex:
|
||||
print(f"License: Uncaught Exception: {ex}")
|
||||
log.exception(f"License: Uncaught Exception: {ex}")
|
||||
return False
|
||||
|
||||
|
||||
@@ -129,6 +150,7 @@ def get_http_authorization_cred(auth_header: str):
|
||||
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
token = None
|
||||
@@ -181,7 +203,10 @@ def get_current_user(
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -66,7 +66,7 @@ async def generate_direct_chat_completion(
|
||||
user: Any,
|
||||
models: dict,
|
||||
):
|
||||
print("generate_direct_chat_completion")
|
||||
log.info("generate_direct_chat_completion")
|
||||
|
||||
metadata = form_data.pop("metadata", {})
|
||||
|
||||
@@ -103,7 +103,7 @@ async def generate_direct_chat_completion(
|
||||
}
|
||||
)
|
||||
|
||||
print("res", res)
|
||||
log.info(f"res: {res}")
|
||||
|
||||
if res.get("status", False):
|
||||
# Define a generator to stream responses
|
||||
@@ -200,7 +200,7 @@ async def generate_chat_completion(
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if model["owned_by"] == "arena":
|
||||
if model.get("owned_by") == "arena":
|
||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
||||
if model_ids and filter_mode == "exclude":
|
||||
@@ -253,7 +253,7 @@ async def generate_chat_completion(
|
||||
return await generate_function_chat_completion(
|
||||
request, form_data, user=user, models=models
|
||||
)
|
||||
if model["owned_by"] == "ollama":
|
||||
if model.get("owned_by") == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
response = await generate_ollama_chat_completion(
|
||||
@@ -285,7 +285,7 @@ chat_completion = generate_chat_completion
|
||||
|
||||
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
@@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
@@ -432,7 +432,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model):
|
||||
@@ -61,7 +67,12 @@ async def process_filter_functions(
|
||||
try:
|
||||
# Prepare parameters
|
||||
sig = inspect.signature(handler)
|
||||
params = {"body": form_data} | {
|
||||
|
||||
params = {"body": form_data}
|
||||
if filter_type == "stream":
|
||||
params = {"event": form_data}
|
||||
|
||||
params = params | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
@@ -80,7 +91,7 @@ async def process_filter_functions(
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
# Execute handler
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
@@ -89,7 +100,7 @@ async def process_filter_functions(
|
||||
form_data = handler(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||
log.exception(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||
raise e
|
||||
|
||||
# Handle file cleanup for inlet
|
||||
|
||||
140
backend/open_webui/utils/logger.py
Normal file
140
backend/open_webui/utils/logger.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from open_webui.env import (
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
AUDIT_LOG_LEVEL,
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def stdout_format(record: "Record") -> str:
    """
    Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).

    Parameters:
    record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
    Returns:
    str: A format template for stdout; it refers to the serialized extras through ``extra[extra_json]``.
    """
    # default=str keeps the formatter from raising on extras JSON cannot encode
    # natively, which would otherwise lose the log line.
    record["extra"]["extra_json"] = json.dumps(record["extra"], default=str)
    return (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
        "<level>{message}</level> - {extra[extra_json]}"
        "\n{exception}"
    )
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
    """
    Bridge between the standard ``logging`` module and Loguru: every record
    handled here is re-emitted through Loguru's logger.
    """

    def emit(self, record):
        """
        Forward one standard ``LogRecord`` to Loguru.

        The level name is reused when Loguru knows it, otherwise the numeric level is passed on. The message and any exception info are forwarded as well.
        """
        try:
            loguru_level = logger.level(record.levelname).name
        except ValueError:
            loguru_level = record.levelno

        # Start a fixed number of frames up the stack, then keep climbing while
        # the frame still belongs to the logging module itself.
        stack_depth = 6
        caller = sys._getframe(stack_depth)
        while caller and caller.f_code.co_filename == logging.__file__:
            caller = caller.f_back
            stack_depth += 1

        forwarder = logger.opt(depth=stack_depth, exception=record.exc_info)
        forwarder.log(loguru_level, record.getMessage())
|
||||
|
||||
|
||||
def file_format(record: "Record"):
    """
    Build the audit-file line for a record as a JSON document.

    The audit fields are read from the record's extras (with a fallback for each missing field), serialized to JSON and stored back under ``extra["file_extra"]``.

    Parameters:
    record (Record): A Loguru record containing extra audit data.
    Returns:
    str: A format template that refers to the JSON through ``extra[file_extra]``.
    """
    extras = record["extra"]

    # (field, fallback) pairs, in the order they appear after "id" and "timestamp".
    fallbacks = (
        ("user", dict()),
        ("audit_level", ""),
        ("verb", ""),
        ("request_uri", ""),
        ("response_status_code", 0),
        ("source_ip", ""),
        ("user_agent", ""),
        ("request_object", b""),
        ("response_object", b""),
        ("extra", {}),
    )

    audit_data = {
        "id": extras.get("id", ""),
        "timestamp": int(record["time"].timestamp()),
    }
    for field, fallback in fallbacks:
        audit_data[field] = extras.get(field, fallback)

    extras["file_extra"] = json.dumps(audit_data, default=str)
    return "{extra[file_extra]}\n"
|
||||
|
||||
|
||||
def start_logger():
    """
    Set up Loguru as the logging backend and route standard logging through it.

    Steps performed:
    - Remove Loguru's existing handlers.
    - Add a stdout handler at ``GLOBAL_LOG_LEVEL`` for records that carry no ``auditable`` extra.
    - When ``AUDIT_LOG_LEVEL`` is not ``"NONE"``, add a rotating, zip-compressed file handler that only accepts records whose ``auditable`` extra is True. A failure to add it is logged and otherwise ignored.
    - Reconfigure the standard ``logging`` module to use ``InterceptHandler`` and adjust the level and handlers of the Uvicorn loggers.
    """

    def is_regular_record(record):
        return "auditable" not in record["extra"]

    def is_audit_record(record):
        return record["extra"].get("auditable") is True

    logger.remove()

    logger.add(
        sys.stdout,
        level=GLOBAL_LOG_LEVEL,
        format=stdout_format,
        filter=is_regular_record,
    )

    audit_enabled = AUDIT_LOG_LEVEL != "NONE"
    if audit_enabled:
        try:
            logger.add(
                AUDIT_LOGS_FILE_PATH,
                level="INFO",
                rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
                compression="zip",
                format=file_format,
                filter=is_audit_record,
            )
        except Exception as e:
            logger.error(f"Failed to initialize audit log file handler: {str(e)}")

    logging.basicConfig(
        handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
    )

    # These Uvicorn loggers lose their own handlers.
    for name in ("uvicorn", "uvicorn.error"):
        target = logging.getLogger(name)
        target.setLevel(GLOBAL_LOG_LEVEL)
        target.handlers = []

    # The access logger gets its own intercepting handler.
    access_logger = logging.getLogger("uvicorn.access")
    access_logger.setLevel(GLOBAL_LOG_LEVEL)
    access_logger.handlers = [InterceptHandler()]

    logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
@@ -322,78 +322,95 @@ async def chat_web_search_handler(
|
||||
)
|
||||
return form_data
|
||||
|
||||
searchQuery = queries[0]
|
||||
all_results = []
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
results = await process_web_search(
|
||||
request,
|
||||
SearchForm(
|
||||
**{
|
||||
for searchQuery in queries:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
}
|
||||
),
|
||||
user,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if results:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Searched {{count}} sites",
|
||||
try:
|
||||
results = await process_web_search(
|
||||
request,
|
||||
SearchForm(
|
||||
**{
|
||||
"query": searchQuery,
|
||||
"urls": results["filenames"],
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
form_data["files"] = files
|
||||
else:
|
||||
if results:
|
||||
all_results.append(results)
|
||||
files = form_data.get("files", [])
|
||||
|
||||
if results.get("collection_name"):
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
elif results.get("docs"):
|
||||
files.append(
|
||||
{
|
||||
"docs": results.get("docs", []),
|
||||
"name": searchQuery,
|
||||
"type": "web_search",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
|
||||
form_data["files"] = files
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search results found",
|
||||
"description": 'Error searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
if all_results:
|
||||
urls = []
|
||||
for results in all_results:
|
||||
if "filenames" in results:
|
||||
urls.extend(results["filenames"])
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Error searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"description": "Searched {{count}} sites",
|
||||
"urls": urls,
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search results found",
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
@@ -503,6 +520,7 @@ async def chat_completion_files_handler(
|
||||
sources = []
|
||||
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
queries = []
|
||||
try:
|
||||
queries_response = await generate_queries(
|
||||
request,
|
||||
@@ -528,8 +546,8 @@ async def chat_completion_files_handler(
|
||||
queries_response = {"queries": [queries_response]}
|
||||
|
||||
queries = queries_response.get("queries", [])
|
||||
except Exception as e:
|
||||
queries = []
|
||||
except:
|
||||
pass
|
||||
|
||||
if len(queries) == 0:
|
||||
queries = [get_last_user_message(body["messages"])]
|
||||
@@ -541,6 +559,7 @@ async def chat_completion_files_handler(
|
||||
sources = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: get_sources_from_files(
|
||||
request=request,
|
||||
files=files,
|
||||
queries=queries,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
@@ -550,9 +569,9 @@ async def chat_completion_files_handler(
|
||||
reranking_function=request.app.state.rf,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
full_context=request.app.state.config.RAG_FULL_CONTEXT,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
@@ -728,6 +747,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
tool_ids = form_data.pop("tool_ids", None)
|
||||
files = form_data.pop("files", None)
|
||||
|
||||
# Remove files duplicates
|
||||
if files:
|
||||
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
|
||||
@@ -785,8 +805,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
if len(sources) > 0:
|
||||
context_string = ""
|
||||
for source_idx, source in enumerate(sources):
|
||||
source_id = source.get("source", {}).get("name", "")
|
||||
|
||||
if "document" in source:
|
||||
for doc_idx, doc_context in enumerate(source["document"]):
|
||||
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||
@@ -806,7 +824,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
# Workaround for Ollama 2.0+ system prompt issue
|
||||
# TODO: replace with add_or_update_system_message
|
||||
if model["owned_by"] == "ollama":
|
||||
if model.get("owned_by") == "ollama":
|
||||
form_data["messages"] = prepend_to_first_user_message_content(
|
||||
rag_template(
|
||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||
@@ -1038,6 +1056,21 @@ async def process_chat_response(
|
||||
):
|
||||
return response
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": metadata.get("model"),
|
||||
}
|
||||
filter_ids = get_sorted_filter_ids(form_data.get("model"))
|
||||
|
||||
# Streaming response
|
||||
if event_emitter and event_caller:
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
@@ -1117,12 +1150,12 @@ async def process_chat_response(
|
||||
|
||||
if reasoning_duration is not None:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
|
||||
@@ -1218,9 +1251,9 @@ async def process_chat_response(
|
||||
return attributes
|
||||
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
for tag in tags:
|
||||
for start_tag, end_tag in tags:
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
start_tag_pattern = rf"<{tag}(\s.*?)?>"
|
||||
start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
|
||||
match = re.search(start_tag_pattern, content)
|
||||
if match:
|
||||
attr_content = (
|
||||
@@ -1253,7 +1286,8 @@ async def process_chat_response(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": content_type,
|
||||
"tag": tag,
|
||||
"start_tag": start_tag,
|
||||
"end_tag": end_tag,
|
||||
"attributes": attributes,
|
||||
"content": "",
|
||||
"started_at": time.time(),
|
||||
@@ -1265,9 +1299,10 @@ async def process_chat_response(
|
||||
|
||||
break
|
||||
elif content_blocks[-1]["type"] == content_type:
|
||||
tag = content_blocks[-1]["tag"]
|
||||
start_tag = content_blocks[-1]["start_tag"]
|
||||
end_tag = content_blocks[-1]["end_tag"]
|
||||
# Match end tag e.g., </tag>
|
||||
end_tag_pattern = rf"</{tag}>"
|
||||
end_tag_pattern = rf"<{re.escape(end_tag)}>"
|
||||
|
||||
# Check if the content has the end tag
|
||||
if re.search(end_tag_pattern, content):
|
||||
@@ -1275,7 +1310,7 @@ async def process_chat_response(
|
||||
|
||||
block_content = content_blocks[-1]["content"]
|
||||
# Strip start and end tags from the content
|
||||
start_tag_pattern = rf"<{tag}(.*?)>"
|
||||
start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
|
||||
block_content = re.sub(
|
||||
start_tag_pattern, "", block_content
|
||||
).strip()
|
||||
@@ -1340,7 +1375,7 @@ async def process_chat_response(
|
||||
|
||||
# Clean processed content
|
||||
content = re.sub(
|
||||
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
|
||||
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
|
||||
"",
|
||||
content,
|
||||
flags=re.DOTALL,
|
||||
@@ -1353,7 +1388,22 @@ async def process_chat_response(
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
content = message.get("content", "") if message else ""
|
||||
|
||||
last_assistant_message = None
|
||||
try:
|
||||
if form_data["messages"][-1]["role"] == "assistant":
|
||||
last_assistant_message = get_last_assistant_message(
|
||||
form_data["messages"]
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
content = (
|
||||
message.get("content", "")
|
||||
if message
|
||||
else last_assistant_message if last_assistant_message else ""
|
||||
)
|
||||
|
||||
content_blocks = [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -1363,19 +1413,24 @@ async def process_chat_response(
|
||||
|
||||
# We might want to disable this by default
|
||||
DETECT_REASONING = True
|
||||
DETECT_SOLUTION = True
|
||||
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
|
||||
"code_interpreter", False
|
||||
)
|
||||
|
||||
reasoning_tags = [
|
||||
"think",
|
||||
"thinking",
|
||||
"reason",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"Thought",
|
||||
("think", "/think"),
|
||||
("thinking", "/thinking"),
|
||||
("reason", "/reason"),
|
||||
("reasoning", "/reasoning"),
|
||||
("thought", "/thought"),
|
||||
("Thought", "/Thought"),
|
||||
("|begin_of_thought|", "|end_of_thought|"),
|
||||
]
|
||||
code_interpreter_tags = ["code_interpreter"]
|
||||
|
||||
code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
|
||||
|
||||
solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
@@ -1419,119 +1474,154 @@ async def process_chat_response(
|
||||
try:
|
||||
data = json.loads(data)
|
||||
|
||||
if "selected_model_id" in data:
|
||||
model_id = data["selected_model_id"]
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": model_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
data, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
delta_tool_calls = delta.get("tool_calls", None)
|
||||
|
||||
if delta_tool_calls:
|
||||
for delta_tool_call in delta_tool_calls:
|
||||
tool_call_index = delta_tool_call.get("index")
|
||||
|
||||
if tool_call_index is not None:
|
||||
if (
|
||||
len(response_tool_calls)
|
||||
<= tool_call_index
|
||||
):
|
||||
response_tool_calls.append(
|
||||
delta_tool_call
|
||||
)
|
||||
else:
|
||||
delta_name = delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("name")
|
||||
delta_arguments = delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("arguments")
|
||||
|
||||
if delta_name:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"]["name"] += delta_name
|
||||
|
||||
if delta_arguments:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"][
|
||||
"arguments"
|
||||
] += delta_arguments
|
||||
|
||||
value = delta.get("content")
|
||||
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
if not content_blocks:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
content_blocks[-1]["content"] = (
|
||||
content_blocks[-1]["content"] + value
|
||||
if data:
|
||||
if "selected_model_id" in data:
|
||||
model_id = data["selected_model_id"]
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"selectedModelId": model_id,
|
||||
},
|
||||
)
|
||||
|
||||
if DETECT_REASONING:
|
||||
content, content_blocks, _ = (
|
||||
tag_content_handler(
|
||||
"reasoning",
|
||||
reasoning_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
else:
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
usage = data.get("usage", {})
|
||||
if usage:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": {
|
||||
"usage": usage,
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
delta_tool_calls = delta.get("tool_calls", None)
|
||||
|
||||
if delta_tool_calls:
|
||||
for delta_tool_call in delta_tool_calls:
|
||||
tool_call_index = delta_tool_call.get(
|
||||
"index"
|
||||
)
|
||||
|
||||
if tool_call_index is not None:
|
||||
if (
|
||||
len(response_tool_calls)
|
||||
<= tool_call_index
|
||||
):
|
||||
response_tool_calls.append(
|
||||
delta_tool_call
|
||||
)
|
||||
else:
|
||||
delta_name = delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("name")
|
||||
delta_arguments = (
|
||||
delta_tool_call.get(
|
||||
"function", {}
|
||||
).get("arguments")
|
||||
)
|
||||
|
||||
if delta_name:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"][
|
||||
"name"
|
||||
] += delta_name
|
||||
|
||||
if delta_arguments:
|
||||
response_tool_calls[
|
||||
tool_call_index
|
||||
]["function"][
|
||||
"arguments"
|
||||
] += delta_arguments
|
||||
|
||||
value = delta.get("content")
|
||||
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
if not content_blocks:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
|
||||
content_blocks[-1]["content"] = (
|
||||
content_blocks[-1]["content"] + value
|
||||
)
|
||||
|
||||
if DETECT_CODE_INTERPRETER:
|
||||
content, content_blocks, end = (
|
||||
tag_content_handler(
|
||||
"code_interpreter",
|
||||
code_interpreter_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
if DETECT_REASONING:
|
||||
content, content_blocks, _ = (
|
||||
tag_content_handler(
|
||||
"reasoning",
|
||||
reasoning_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if end:
|
||||
break
|
||||
if DETECT_CODE_INTERPRETER:
|
||||
content, content_blocks, end = (
|
||||
tag_content_handler(
|
||||
"code_interpreter",
|
||||
code_interpreter_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
)
|
||||
)
|
||||
|
||||
if ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
if end:
|
||||
break
|
||||
|
||||
if DETECT_SOLUTION:
|
||||
content, content_blocks, _ = (
|
||||
tag_content_handler(
|
||||
"solution",
|
||||
solution_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
)
|
||||
)
|
||||
|
||||
if ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = {
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = {
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
done = "data: [DONE]" in line
|
||||
if done:
|
||||
@@ -1736,6 +1826,7 @@ async def process_chat_response(
|
||||
== "password"
|
||||
else None
|
||||
),
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
output = {
|
||||
@@ -1829,7 +1920,10 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
print(content_blocks, serialize_content_blocks(content_blocks))
|
||||
log.info(f"content_blocks={content_blocks}")
|
||||
log.info(
|
||||
f"serialize_content_blocks={serialize_content_blocks(content_blocks)}"
|
||||
)
|
||||
|
||||
try:
|
||||
res = await generate_chat_completion(
|
||||
@@ -1900,7 +1994,7 @@ async def process_chat_response(
|
||||
|
||||
await background_tasks_handler()
|
||||
except asyncio.CancelledError:
|
||||
print("Task was cancelled!")
|
||||
log.warning("Task was cancelled!")
|
||||
await event_emitter({"type": "task-cancelled"})
|
||||
|
||||
if not ENABLE_REALTIME_CHAT_SAVE:
|
||||
@@ -1921,17 +2015,34 @@ async def process_chat_response(
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
else:
|
||||
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n"
|
||||
|
||||
for event in events:
|
||||
yield wrap_item(json.dumps(event))
|
||||
event, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=event,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
|
||||
if event:
|
||||
yield wrap_item(json.dumps(event))
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
data, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=filter_ids,
|
||||
filter_type="stream",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
|
||||
if data:
|
||||
yield data
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
|
||||
@@ -2,6 +2,7 @@ import hashlib
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
@@ -9,6 +10,10 @@ import json
|
||||
|
||||
|
||||
import collections.abc
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
@@ -413,7 +418,7 @@ def parse_ollama_modelfile(model_text):
|
||||
elif param_type is bool:
|
||||
value = value.lower() == "true"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to parse parameter {param}: {e}")
|
||||
continue
|
||||
|
||||
data["params"][param] = value
|
||||
|
||||
@@ -22,6 +22,7 @@ from open_webui.config import (
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
@@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def get_all_base_models(request: Request):
|
||||
async def get_all_base_models(request: Request, user: UserModel = None):
|
||||
function_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
if request.app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await openai.get_all_models(request)
|
||||
openai_models = await openai.get_all_models(request, user=user)
|
||||
openai_models = openai_models["data"]
|
||||
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await ollama.get_all_models(request)
|
||||
ollama_models = await ollama.get_all_models(request, user=user)
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
|
||||
return models
|
||||
|
||||
|
||||
async def get_all_models(request):
|
||||
models = await get_all_base_models(request)
|
||||
async def get_all_models(request, user: UserModel = None):
|
||||
models = await get_all_base_models(request, user=user)
|
||||
|
||||
# If there are no models, return an empty list
|
||||
if len(models) == 0:
|
||||
@@ -142,7 +143,7 @@ async def get_all_models(request):
|
||||
custom_model.base_model_id == model["id"]
|
||||
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||
):
|
||||
owned_by = model["owned_by"]
|
||||
owned_by = model.get("owned_by", "unknown owner")
|
||||
if "pipe" in model:
|
||||
pipe = model["pipe"]
|
||||
break
|
||||
|
||||
@@ -140,7 +140,14 @@ class OAuthManager:
|
||||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
|
||||
# Nested claim search for groups claim
|
||||
if oauth_claim:
|
||||
claim_data = user_data
|
||||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
user_oauth_groups = claim_data if isinstance(claim_data, list) else []
|
||||
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
@@ -239,11 +246,46 @@ class OAuthManager:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||
email = user_data.get(email_claim, "").lower()
|
||||
email = user_data.get(email_claim, "")
|
||||
# We currently mandate that email addresses are provided
|
||||
if not email:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
# If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
||||
if provider == "github":
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/user/emails", headers=headers
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
emails = await resp.json()
|
||||
# use the primary email as the user's email
|
||||
primary_email = next(
|
||||
(e["email"] for e in emails if e.get("primary")),
|
||||
None,
|
||||
)
|
||||
if primary_email:
|
||||
email = primary_email
|
||||
else:
|
||||
log.warning(
|
||||
"No primary email found in GitHub response"
|
||||
)
|
||||
raise HTTPException(
|
||||
400, detail=ERROR_MESSAGES.INVALID_CRED
|
||||
)
|
||||
else:
|
||||
log.warning("Failed to fetch GitHub email")
|
||||
raise HTTPException(
|
||||
400, detail=ERROR_MESSAGES.INVALID_CRED
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning(f"Error fetching GitHub email: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
else:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
email = email.lower()
|
||||
if (
|
||||
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||
and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||
@@ -273,21 +315,10 @@ class OAuthManager:
|
||||
if not user:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
if (
|
||||
request.app.state.USER_COUNT
|
||||
and user_count >= request.app.state.USER_COUNT
|
||||
):
|
||||
raise HTTPException(
|
||||
403,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(
|
||||
user_data.get("email", "").lower()
|
||||
)
|
||||
existing_user = Users.get_user_by_email(email)
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
|
||||
from typing import Callable, Optional
|
||||
import json
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
@@ -67,38 +68,49 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
|
||||
|
||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
opts = [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"seed",
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"num_ctx",
|
||||
"num_batch",
|
||||
"num_keep",
|
||||
"repeat_last_n",
|
||||
"tfs_z",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"use_mmap",
|
||||
"use_mlock",
|
||||
"num_thread",
|
||||
"num_gpu",
|
||||
]
|
||||
mappings = {i: lambda x: x for i in opts}
|
||||
form_data = apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
# Convert OpenAI parameter names to Ollama parameter names if needed.
|
||||
name_differences = {
|
||||
"max_tokens": "num_predict",
|
||||
"frequency_penalty": "repeat_penalty",
|
||||
}
|
||||
|
||||
for key, value in name_differences.items():
|
||||
if (param := params.get(key, None)) is not None:
|
||||
form_data[value] = param
|
||||
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
|
||||
params[value] = params[key]
|
||||
del params[key]
|
||||
|
||||
return form_data
|
||||
# See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": float,
|
||||
"seed": lambda x: x,
|
||||
"mirostat": int,
|
||||
"mirostat_eta": float,
|
||||
"mirostat_tau": float,
|
||||
"num_ctx": int,
|
||||
"num_batch": int,
|
||||
"num_keep": int,
|
||||
"num_predict": int,
|
||||
"repeat_last_n": int,
|
||||
"top_k": int,
|
||||
"min_p": float,
|
||||
"typical_p": float,
|
||||
"repeat_penalty": float,
|
||||
"presence_penalty": float,
|
||||
"frequency_penalty": float,
|
||||
"penalize_newline": bool,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
"numa": bool,
|
||||
"num_gpu": int,
|
||||
"main_gpu": int,
|
||||
"low_vram": bool,
|
||||
"vocab_only": bool,
|
||||
"use_mmap": bool,
|
||||
"use_mlock": bool,
|
||||
"num_thread": int,
|
||||
}
|
||||
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
|
||||
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
||||
@@ -109,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
||||
new_message = {"role": message["role"]}
|
||||
|
||||
content = message.get("content", [])
|
||||
tool_calls = message.get("tool_calls", None)
|
||||
tool_call_id = message.get("tool_call_id", None)
|
||||
|
||||
# Check if the content is a string (just a simple message)
|
||||
if isinstance(content, str):
|
||||
if isinstance(content, str) and not tool_calls:
|
||||
# If the content is a string, it's pure text
|
||||
new_message["content"] = content
|
||||
|
||||
# If message is a tool call, add the tool call id to the message
|
||||
if tool_call_id:
|
||||
new_message["tool_call_id"] = tool_call_id
|
||||
|
||||
elif tool_calls:
|
||||
# If tool calls are present, add them to the message
|
||||
ollama_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
ollama_tool_call = {
|
||||
"index": tool_call.get("index", 0),
|
||||
"id": tool_call.get("id", None),
|
||||
"function": {
|
||||
"name": tool_call.get("function", {}).get("name", ""),
|
||||
"arguments": json.loads(
|
||||
tool_call.get("function", {}).get("arguments", {})
|
||||
),
|
||||
},
|
||||
}
|
||||
ollama_tool_calls.append(ollama_tool_call)
|
||||
new_message["tool_calls"] = ollama_tool_calls
|
||||
|
||||
# Put the content to empty string (Ollama requires an empty string for tool calls)
|
||||
new_message["content"] = ""
|
||||
|
||||
else:
|
||||
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
|
||||
content_text = ""
|
||||
@@ -174,33 +213,28 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
ollama_payload["format"] = openai_payload["format"]
|
||||
|
||||
# If there are advanced parameters in the payload, format them in Ollama's options field
|
||||
ollama_options = {}
|
||||
|
||||
if openai_payload.get("options"):
|
||||
ollama_payload["options"] = openai_payload["options"]
|
||||
ollama_options = openai_payload["options"]
|
||||
|
||||
# Handle parameters which map directly
|
||||
for param in ["temperature", "top_p", "seed"]:
|
||||
if param in openai_payload:
|
||||
ollama_options[param] = openai_payload[param]
|
||||
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_tokens" in ollama_options:
|
||||
ollama_options["num_predict"] = ollama_options["max_tokens"]
|
||||
del ollama_options[
|
||||
"max_tokens"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
|
||||
# Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_completion_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
|
||||
elif "max_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_tokens"]
|
||||
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
|
||||
if "system" in ollama_options:
|
||||
ollama_payload["system"] = ollama_options["system"]
|
||||
del ollama_options[
|
||||
"system"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
|
||||
# Handle frequency / presence_penalty, which needs renaming and checking
|
||||
if "frequency_penalty" in openai_payload:
|
||||
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
|
||||
|
||||
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
|
||||
# We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
|
||||
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
|
||||
|
||||
# Add options to payload if any have been set
|
||||
if ollama_options:
|
||||
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
|
||||
if "stop" in openai_payload:
|
||||
ollama_options = ollama_payload.get("options", {})
|
||||
ollama_options["stop"] = openai_payload.get("stop")
|
||||
ollama_payload["options"] = ollama_options
|
||||
|
||||
if "metadata" in openai_payload:
|
||||
|
||||
@@ -45,7 +45,7 @@ def extract_frontmatter(content):
|
||||
frontmatter[key.strip()] = value.strip()
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Failed to extract frontmatter: {e}")
|
||||
return {}
|
||||
|
||||
return frontmatter
|
||||
|
||||
@@ -24,17 +24,8 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
|
||||
return openai_tool_calls
|
||||
|
||||
|
||||
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
model = ollama_response.get("model", "ollama")
|
||||
message_content = ollama_response.get("message", {}).get("content", "")
|
||||
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
if tool_calls:
|
||||
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
|
||||
|
||||
data = ollama_response
|
||||
usage = {
|
||||
def convert_ollama_usage_to_openai(data: dict) -> dict:
|
||||
return {
|
||||
"response_token/s": (
|
||||
round(
|
||||
(
|
||||
@@ -66,14 +57,42 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
"total_duration": data.get("total_duration", 0),
|
||||
"load_duration": data.get("load_duration", 0),
|
||||
"prompt_eval_count": data.get("prompt_eval_count", 0),
|
||||
"prompt_tokens": int(
|
||||
data.get("prompt_eval_count", 0)
|
||||
), # This is the OpenAI compatible key
|
||||
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
|
||||
"eval_count": data.get("eval_count", 0),
|
||||
"completion_tokens": int(
|
||||
data.get("eval_count", 0)
|
||||
), # This is the OpenAI compatible key
|
||||
"eval_duration": data.get("eval_duration", 0),
|
||||
"approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
|
||||
(data.get("total_duration", 0) or 0) // 1_000_000_000
|
||||
),
|
||||
"total_tokens": int( # This is the OpenAI compatible key
|
||||
data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
|
||||
),
|
||||
"completion_tokens_details": { # This is the OpenAI compatible key
|
||||
"reasoning_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
model = ollama_response.get("model", "ollama")
|
||||
message_content = ollama_response.get("message", {}).get("content", "")
|
||||
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
if tool_calls:
|
||||
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
|
||||
|
||||
data = ollama_response
|
||||
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
response = openai_chat_completion_message_template(
|
||||
model, message_content, openai_tool_calls, usage
|
||||
)
|
||||
@@ -85,7 +104,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
data = json.loads(data)
|
||||
|
||||
model = data.get("model", "ollama")
|
||||
message_content = data.get("message", {}).get("content", "")
|
||||
message_content = data.get("message", {}).get("content", None)
|
||||
tool_calls = data.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
@@ -96,48 +115,10 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
|
||||
usage = None
|
||||
if done:
|
||||
usage = {
|
||||
"response_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("eval_count", 0)
|
||||
/ ((data.get("eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"prompt_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("prompt_eval_count", 0)
|
||||
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("prompt_eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"total_duration": data.get("total_duration", 0),
|
||||
"load_duration": data.get("load_duration", 0),
|
||||
"prompt_eval_count": data.get("prompt_eval_count", 0),
|
||||
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
|
||||
"eval_count": data.get("eval_count", 0),
|
||||
"eval_duration": data.get("eval_duration", 0),
|
||||
"approximate_total": (
|
||||
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
|
||||
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
|
||||
}
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
data = openai_chat_chunk_message_template(
|
||||
model, message_content if not done else None, openai_tool_calls, usage
|
||||
model, message_content, openai_tool_calls, usage
|
||||
)
|
||||
|
||||
line = f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_task_model_id(
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id]["owned_by"] == "ollama":
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user