diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index e30e3c2b3..db8df5ee5 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -3,18 +3,19 @@ import logging import json from contextlib import contextmanager -from peewee_migrate import Router -from apps.webui.internal.wrappers import register_connection from typing import Optional, Any from typing_extensions import Self from sqlalchemy import create_engine, types, Dialect +from sqlalchemy.sql.type_api import _T from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.sql.type_api import _T -from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR + +from peewee_migrate import Router +from apps.webui.internal.wrappers import register_connection +from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) @@ -42,15 +43,6 @@ class JSONField(types.TypeDecorator): return json.loads(value) -# Check if the file exists -if os.path.exists(f"{DATA_DIR}/ollama.db"): - # Rename the file - os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") - log.info("Database migrated from Ollama-WebUI successfully.") -else: - pass - - # Workaround to handle the peewee migration # This is required to ensure the peewee migration is handled before the alembic migration def handle_peewee_migration(DATABASE_URL): @@ -94,7 +86,6 @@ Base = declarative_base() Session = scoped_session(SessionLocal) -# Dependency def get_session(): db = SessionLocal() try: diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py index cc4a42421..19523064a 100644 --- a/backend/apps/webui/internal/wrappers.py +++ b/backend/apps/webui/internal/wrappers.py @@ -6,7 +6,7 @@ import logging from playhouse.db_url import connect, parse from playhouse.shortcuts import ReconnectMixin -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 3cbe8c887..601c7c9a4 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -4,12 +4,12 @@ import uuid import logging from sqlalchemy import String, Column, Boolean, Text -from apps.webui.models.users import UserModel, Users from utils.utils import verify_password +from apps.webui.models.users import UserModel, Users from apps.webui.internal.db import Base, get_db -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 4157c2c95..15dd63663 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -9,7 +9,7 @@ from apps.webui.internal.db import Base, get_db import json -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index e1d1cec9f..1b7175124 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -9,7 +9,7 @@ from apps.webui.internal.db import JSONField, Base, get_db import json -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 3afdc1ea9..10d811148 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -12,7 +12,7 @@ import json import copy -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 616beb2a9..0a36da987 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -6,7 +6,7 @@ from sqlalchemy import Column, BigInteger, Text from apps.webui.internal.db import Base, JSONField, get_db -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS import time diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7ce06cb60..605cca2e7 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -10,7 +10,7 @@ from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, get_db -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index c8c56fb97..2f4c532b8 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -11,7 +11,7 @@ import json import copy -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/config.py b/backend/config.py index d7caadca4..4dde80eda 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,13 +1,17 @@ +from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func +from contextlib import contextmanager + + import os import sys import logging import importlib.metadata import pkgutil from urllib.parse import urlparse +from datetime import datetime import chromadb from chromadb import Settings -from bs4 import BeautifulSoup from typing import TypeVar, Generic from pydantic import BaseModel from typing import Optional @@ -16,68 +20,39 @@ from pathlib import Path import json import yaml -import markdown import requests import shutil + +from apps.webui.internal.db import Base, get_db + from constants import ERROR_MESSAGES -#################################### -# Load .env file -#################################### - -BACKEND_DIR = Path(__file__).parent # the path containing this file -BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ - -print(BASE_DIR) - -try: - from dotenv import load_dotenv, find_dotenv - - load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) -except ImportError: - print("dotenv not installed, skipping...") - - -#################################### -# LOGGING -#################################### - -log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] - -GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() -if GLOBAL_LOG_LEVEL in log_levels: - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True) -else: - GLOBAL_LOG_LEVEL = "INFO" - -log = logging.getLogger(__name__) -log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") - -log_sources = [ - "AUDIO", - "COMFYUI", - "CONFIG", - "DB", - "IMAGES", - "MAIN", - "MODELS", - "OLLAMA", - "OPENAI", - "RAG", - "WEBHOOK", -] - -SRC_LOG_LEVELS = {} - -for source in log_sources: - log_env_var = source + "_LOG_LEVEL" - SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper() - if SRC_LOG_LEVELS[source] not in log_levels: - SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL - log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}") - -log.setLevel(SRC_LOG_LEVELS["CONFIG"]) +from env import ( + ENV, + VERSION, + SAFE_MODE, + GLOBAL_LOG_LEVEL, + SRC_LOG_LEVELS, + BASE_DIR, + DATA_DIR, + BACKEND_DIR, + FRONTEND_BUILD_DIR, + WEBUI_NAME, + WEBUI_URL, + WEBUI_FAVICON_URL, + WEBUI_BUILD_HASH, + CONFIG_DATA, + DATABASE_URL, + CHANGELOG, + WEBUI_AUTH, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, + log, +) class EndpointFilter(logging.Filter): @@ -88,135 +63,62 @@ class EndpointFilter(logging.Filter): # Filter out /endpoint logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) - -WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") -if WEBUI_NAME != "Open WebUI": - WEBUI_NAME += " (Open WebUI)" - -WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") - -WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" - - -#################################### -# ENV (dev,test,prod) -#################################### - -ENV = os.environ.get("ENV", "dev") - -try: - PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) -except Exception: - try: - PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} - except importlib.metadata.PackageNotFoundError: - PACKAGE_DATA = {"version": "0.0.0"} - -VERSION = PACKAGE_DATA["version"] - - -# Function to parse each section -def parse_section(section): - items = [] - for li in section.find_all("li"): - # Extract raw HTML string - raw_html = str(li) - - # Extract text without HTML tags - text = li.get_text(separator=" ", strip=True) - - # Split into title and content - parts = text.split(": ", 1) - title = parts[0].strip() if len(parts) > 1 else "" - content = parts[1].strip() if len(parts) > 1 else text - - items.append({"title": title, "content": content, "raw": raw_html}) - return items - - -try: - changelog_path = BASE_DIR / "CHANGELOG.md" - with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: - changelog_content = file.read() - -except Exception: - changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() - - -# Convert markdown content to HTML -html_content = markdown.markdown(changelog_content) - -# Parse the HTML content -soup = BeautifulSoup(html_content, "html.parser") - -# Initialize JSON structure -changelog_json = {} - -# Iterate over each version -for version in soup.find_all("h2"): - version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets - date = version.get_text().strip().split(" - ")[1] - - version_data = {"date": date} - - # Find the next sibling that is a h3 tag (section title) - current = version.find_next_sibling() - - while current and current.name != "h2": - if current.name == "h3": - section_title = current.get_text().lower() # e.g., "added", "fixed" - section_items = parse_section(current.find_next_sibling("ul")) - version_data[section_title] = section_items - - # Move to the next element - current = current.find_next_sibling() - - changelog_json[version_number] = version_data - - -CHANGELOG = changelog_json - -#################################### -# SAFE_MODE -#################################### - -SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" - -#################################### -# WEBUI_BUILD_HASH -#################################### - -WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") - -#################################### -# DATA/FRONTEND BUILD DIR -#################################### - -DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() -FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() - -RESET_CONFIG_ON_START = ( - os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" -) -if RESET_CONFIG_ON_START: - try: - os.remove(f"{DATA_DIR}/config.json") - with open(f"{DATA_DIR}/config.json", "w") as f: - f.write("{}") - except Exception: - pass - -try: - CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text()) -except Exception: - CONFIG_DATA = {} - - #################################### # Config helpers #################################### +# Function to run the alembic migrations +def run_migrations(): + print("Running migrations") + try: + from alembic.config import Config + from alembic import command + + alembic_cfg = Config("alembic.ini") + command.upgrade(alembic_cfg, "head") + except Exception as e: + print(f"Error: {e}") + + +run_migrations() + + +class Config(Base): + __tablename__ = "config" + + id = Column(Integer, primary_key=True) + data = Column(JSON, nullable=False) + version = Column(Integer, nullable=False, default=0) + created_at = Column(DateTime, nullable=False, server_default=func.now()) + updated_at = Column(DateTime, nullable=True, onupdate=func.now()) + + +def load_json_config(): + with open(f"{DATA_DIR}/config.json", "r") as file: + return json.load(file) + + +def save_to_db(data): + with get_db() as db: + existing_config = db.query(Config).first() + if not existing_config: + new_config = Config(data=data, version=0) + db.add(new_config) + else: + existing_config.data = data + existing_config.updated_at = datetime.now() + db.add(existing_config) + db.commit() + + +# When initializing, check if config.json exists and migrate it to the database +if os.path.exists(f"{DATA_DIR}/config.json"): + data = load_json_config() + save_to_db(data) + os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json") + + def save_config(): try: with open(f"{DATA_DIR}/config.json", "w") as f: @@ -225,6 +127,15 @@ def save_config(): log.exception(e) +def get_config(): + with get_db() as db: + config_entry = db.query(Config).order_by(Config.id.desc()).first() + return config_entry.data if config_entry else {} + + +CONFIG_DATA = get_config() + + def get_config_value(config_path: str): path_parts = config_path.split(".") cur_config = CONFIG_DATA @@ -246,7 +157,7 @@ class PersistentConfig(Generic[T]): self.env_value = env_value self.config_value = get_config_value(config_path) if self.config_value is not None: - log.info(f"'{env_name}' loaded from config.json") + log.info(f"'{env_name}' loaded from the latest database entry") self.value = self.config_value else: self.value = env_value @@ -268,19 +179,15 @@ class PersistentConfig(Generic[T]): return super().__getattribute__(item) def save(self): - # Don't save if the value is the same as the env value and the config value - if self.env_value == self.value: - if self.config_value == self.value: - return - log.info(f"Saving '{self.env_name}' to config.json") + log.info(f"Saving '{self.env_name}' to the database") path_parts = self.config_path.split(".") - config = CONFIG_DATA + sub_config = CONFIG_DATA for key in path_parts[:-1]: - if key not in config: - config[key] = {} - config = config[key] - config[path_parts[-1]] = self.value - save_config() + if key not in sub_config: + sub_config[key] = {} + sub_config = sub_config[key] + sub_config[path_parts[-1]] = self.value + save_to_db(CONFIG_DATA) self.config_value = self.value @@ -305,11 +212,6 @@ class AppConfig: # WEBUI_AUTH (Required for security) #################################### -WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" -WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( - "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None -) -WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -999,30 +901,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( ) -#################################### -# WEBUI_SECRET_KEY -#################################### - -WEBUI_SECRET_KEY = os.environ.get( - "WEBUI_SECRET_KEY", - os.environ.get( - "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" - ), # DEPRECATED: remove at next major version -) - -WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get( - "WEBUI_SESSION_COOKIE_SAME_SITE", - os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"), -) - -WEBUI_SESSION_COOKIE_SECURE = os.environ.get( - "WEBUI_SESSION_COOKIE_SECURE", - os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true", -) - -if WEBUI_AUTH and WEBUI_SECRET_KEY == "": - raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) - #################################### # RAG document content extraction #################################### @@ -1553,14 +1431,3 @@ AUDIO_TTS_VOICE = PersistentConfig( "audio.tts.voice", os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice ) - - -#################################### -# Database -#################################### - -DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") - -# Replace the postgres:// with postgresql:// -if "postgres://" in DATABASE_URL: - DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") diff --git a/backend/data/config.json b/backend/data/config.json deleted file mode 100644 index 7c7acde91..000000000 --- a/backend/data/config.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "version": 0, - "ui": { - "default_locale": "", - "prompt_suggestions": [ - { - "title": ["Help me study", "vocabulary for a college entrance exam"], - "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option." - }, - { - "title": ["Give me ideas", "for what to do with my kids' art"], - "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter." - }, - { - "title": ["Tell me a fun fact", "about the Roman Empire"], - "content": "Tell me a random fun fact about the Roman Empire" - }, - { - "title": ["Show me a code snippet", "of a website's sticky header"], - "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript." - }, - { - "title": ["Explain options trading", "if I'm familiar with buying and selling stocks"], - "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks." - }, - { - "title": ["Overcome procrastination", "give me tips"], - "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?" - }, - { - "title": ["Grammar check", "rewrite it for better readability "], - "content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning." - } - ] - } -} diff --git a/backend/env.py b/backend/env.py new file mode 100644 index 000000000..689dc1b6d --- /dev/null +++ b/backend/env.py @@ -0,0 +1,252 @@ +from pathlib import Path +import os +import logging +import sys +import json + + +import importlib.metadata +import pkgutil +from urllib.parse import urlparse +from datetime import datetime + + +import markdown +from bs4 import BeautifulSoup + +from constants import ERROR_MESSAGES + +#################################### +# Load .env file +#################################### + +BACKEND_DIR = Path(__file__).parent # the path containing this file +BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ + +print(BASE_DIR) + +try: + from dotenv import load_dotenv, find_dotenv + + load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) +except ImportError: + print("dotenv not installed, skipping...") + + +#################################### +# LOGGING +#################################### + +log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] + +GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() +if GLOBAL_LOG_LEVEL in log_levels: + logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True) +else: + GLOBAL_LOG_LEVEL = "INFO" + +log = logging.getLogger(__name__) +log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") + +log_sources = [ + "AUDIO", + "COMFYUI", + "CONFIG", + "DB", + "IMAGES", + "MAIN", + "MODELS", + "OLLAMA", + "OPENAI", + "RAG", + "WEBHOOK", +] + +SRC_LOG_LEVELS = {} + +for source in log_sources: + log_env_var = source + "_LOG_LEVEL" + SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper() + if SRC_LOG_LEVELS[source] not in log_levels: + SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL + log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}") + +log.setLevel(SRC_LOG_LEVELS["CONFIG"]) + + +WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") +if WEBUI_NAME != "Open WebUI": + WEBUI_NAME += " (Open WebUI)" + +WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000") + +WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" + + +#################################### +# ENV (dev,test,prod) +#################################### + +ENV = os.environ.get("ENV", "dev") + +try: + PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text()) +except Exception: + try: + PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")} + except importlib.metadata.PackageNotFoundError: + PACKAGE_DATA = {"version": "0.0.0"} + +VERSION = PACKAGE_DATA["version"] + + +# Function to parse each section +def parse_section(section): + items = [] + for li in section.find_all("li"): + # Extract raw HTML string + raw_html = str(li) + + # Extract text without HTML tags + text = li.get_text(separator=" ", strip=True) + + # Split into title and content + parts = text.split(": ", 1) + title = parts[0].strip() if len(parts) > 1 else "" + content = parts[1].strip() if len(parts) > 1 else text + + items.append({"title": title, "content": content, "raw": raw_html}) + return items + + +try: + changelog_path = BASE_DIR / "CHANGELOG.md" + with open(str(changelog_path.absolute()), "r", encoding="utf8") as file: + changelog_content = file.read() + +except Exception: + changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode() + + +# Convert markdown content to HTML +html_content = markdown.markdown(changelog_content) + +# Parse the HTML content +soup = BeautifulSoup(html_content, "html.parser") + +# Initialize JSON structure +changelog_json = {} + +# Iterate over each version +for version in soup.find_all("h2"): + version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets + date = version.get_text().strip().split(" - ")[1] + + version_data = {"date": date} + + # Find the next sibling that is a h3 tag (section title) + current = version.find_next_sibling() + + while current and current.name != "h2": + if current.name == "h3": + section_title = current.get_text().lower() # e.g., "added", "fixed" + section_items = parse_section(current.find_next_sibling("ul")) + version_data[section_title] = section_items + + # Move to the next element + current = current.find_next_sibling() + + changelog_json[version_number] = version_data + + +CHANGELOG = changelog_json + +#################################### +# SAFE_MODE +#################################### + +SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" + +#################################### +# WEBUI_BUILD_HASH +#################################### + +WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build") + +#################################### +# DATA/FRONTEND BUILD DIR +#################################### + +DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve() +FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve() + +RESET_CONFIG_ON_START = ( + os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" +) +if RESET_CONFIG_ON_START: + try: + os.remove(f"{DATA_DIR}/config.json") + with open(f"{DATA_DIR}/config.json", "w") as f: + f.write("{}") + except Exception: + pass + +try: + CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text()) +except Exception: + CONFIG_DATA = {} + + +#################################### +# Database +#################################### + +# Check if the file exists +if os.path.exists(f"{DATA_DIR}/ollama.db"): + # Rename the file + os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") + log.info("Database migrated from Ollama-WebUI successfully.") +else: + pass + +DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") + +# Replace the postgres:// with postgresql:// +if "postgres://" in DATABASE_URL: + DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") + + +#################################### +# WEBUI_AUTH (Required for security) +#################################### + +WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) +WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) + + +#################################### +# WEBUI_SECRET_KEY +#################################### + +WEBUI_SECRET_KEY = os.environ.get( + "WEBUI_SECRET_KEY", + os.environ.get( + "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" + ), # DEPRECATED: remove at next major version +) + +WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get( + "WEBUI_SESSION_COOKIE_SAME_SITE", + os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"), +) + +WEBUI_SESSION_COOKIE_SECURE = os.environ.get( + "WEBUI_SESSION_COOKIE_SECURE", + os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true", +) + +if WEBUI_AUTH and WEBUI_SECRET_KEY == "": + raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) diff --git a/backend/main.py b/backend/main.py index ab557bea1..b8ed68111 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,7 +1,6 @@ import base64 import uuid from contextlib import asynccontextmanager - from authlib.integrations.starlette_client import OAuth from authlib.oidc.core import UserInfo import json @@ -87,6 +86,7 @@ from utils.misc import ( from apps.rag.utils import get_rag_context, rag_template from config import ( + run_migrations, WEBUI_NAME, WEBUI_URL, WEBUI_AUTH, @@ -165,17 +165,6 @@ https://github.com/open-webui/open-webui ) -def run_migrations(): - try: - from alembic.config import Config - from alembic import command - - alembic_cfg = Config("alembic.ini") - command.upgrade(alembic_cfg, "head") - except Exception as e: - print(f"Error: {e}") - - @asynccontextmanager async def lifespan(app: FastAPI): run_migrations() diff --git a/backend/migrations/env.py b/backend/migrations/env.py index 8046abff3..b3b3407fa 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -18,7 +18,7 @@ from apps.webui.models.users import User from apps.webui.models.files import File from apps.webui.models.functions import Function -from config import DATABASE_URL +from env import DATABASE_URL # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/backend/migrations/versions/ca81bd47c050_add_config_table.py b/backend/migrations/versions/ca81bd47c050_add_config_table.py new file mode 100644 index 000000000..b9f708240 --- /dev/null +++ b/backend/migrations/versions/ca81bd47c050_add_config_table.py @@ -0,0 +1,43 @@ +"""Add config table + +Revision ID: ca81bd47c050 +Revises: 7e5b5dc7342b +Create Date: 2024-08-25 15:26:35.241684 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import apps.webui.internal.db + + +# revision identifiers, used by Alembic. +revision: str = "ca81bd47c050" +down_revision: Union[str, None] = "7e5b5dc7342b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + op.create_table( + "config", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("data", sa.JSON(), nullable=False), + sa.Column("version", sa.Integer, nullable=False), + sa.Column( + "created_at", sa.DateTime(), nullable=False, server_default=sa.func.now() + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=True, + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + + +def downgrade(): + op.drop_table("config") diff --git a/backend/utils/utils.py b/backend/utils/utils.py index a4cce38af..4c15ea237 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -10,12 +10,12 @@ from datetime import datetime, timedelta, UTC import jwt import uuid import logging -import config +from env import WEBUI_SECRET_KEY logging.getLogger("passlib").setLevel(logging.ERROR) -SESSION_SECRET = config.WEBUI_SECRET_KEY +SESSION_SECRET = WEBUI_SECRET_KEY ALGORITHM = "HS256" ##############