feat: config.json db migration

This commit is contained in:
Timothy J. Baek 2024-08-25 16:52:36 +02:00
parent 072945c40b
commit 58cf1be20c
16 changed files with 432 additions and 322 deletions

View File

@ -3,18 +3,19 @@ import logging
import json import json
from contextlib import contextmanager from contextlib import contextmanager
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from typing import Optional, Any from typing import Optional, Any
from typing_extensions import Self from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.sql.type_api import _T
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.sql.type_api import _T
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])
@ -42,15 +43,6 @@ class JSONField(types.TypeDecorator):
return json.loads(value) return json.loads(value)
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
# Workaround to handle the peewee migration # Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration # This is required to ensure the peewee migration is handled before the alembic migration
def handle_peewee_migration(DATABASE_URL): def handle_peewee_migration(DATABASE_URL):
@ -94,7 +86,6 @@ Base = declarative_base()
Session = scoped_session(SessionLocal) Session = scoped_session(SessionLocal)
# Dependency
def get_session(): def get_session():
db = SessionLocal() db = SessionLocal()
try: try:

View File

@ -6,7 +6,7 @@ import logging
from playhouse.db_url import connect, parse from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin from playhouse.shortcuts import ReconnectMixin
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])

View File

@ -4,12 +4,12 @@ import uuid
import logging import logging
from sqlalchemy import String, Column, Boolean, Text from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.models.users import UserModel, Users
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -9,7 +9,7 @@ from apps.webui.internal.db import Base, get_db
import json import json
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -9,7 +9,7 @@ from apps.webui.internal.db import JSONField, Base, get_db
import json import json
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -12,7 +12,7 @@ import json
import copy import copy
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -6,7 +6,7 @@ from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
import time import time

View File

@ -10,7 +10,7 @@ from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -11,7 +11,7 @@ import json
import copy import copy
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -1,13 +1,17 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
from contextlib import contextmanager
import os import os
import sys import sys
import logging import logging
import importlib.metadata import importlib.metadata
import pkgutil import pkgutil
from urllib.parse import urlparse from urllib.parse import urlparse
from datetime import datetime
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from bs4 import BeautifulSoup
from typing import TypeVar, Generic from typing import TypeVar, Generic
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
@ -16,68 +20,39 @@ from pathlib import Path
import json import json
import yaml import yaml
import markdown
import requests import requests
import shutil import shutil
from apps.webui.internal.db import Base, get_db
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### from env import (
# Load .env file ENV,
#################################### VERSION,
SAFE_MODE,
BACKEND_DIR = Path(__file__).parent # the path containing this file GLOBAL_LOG_LEVEL,
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ SRC_LOG_LEVELS,
BASE_DIR,
print(BASE_DIR) DATA_DIR,
BACKEND_DIR,
try: FRONTEND_BUILD_DIR,
from dotenv import load_dotenv, find_dotenv WEBUI_NAME,
WEBUI_URL,
load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) WEBUI_FAVICON_URL,
except ImportError: WEBUI_BUILD_HASH,
print("dotenv not installed, skipping...") CONFIG_DATA,
DATABASE_URL,
CHANGELOG,
#################################### WEBUI_AUTH,
# LOGGING WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
#################################### WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SECRET_KEY,
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() log,
if GLOBAL_LOG_LEVEL in log_levels: )
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
class EndpointFilter(logging.Filter): class EndpointFilter(logging.Filter):
@ -88,135 +63,60 @@ class EndpointFilter(logging.Filter):
# Filter out /endpoint # Filter out /endpoint
logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except Exception:
try:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
except importlib.metadata.PackageNotFoundError:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# WEBUI_BUILD_HASH
####################################
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except Exception:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except Exception:
CONFIG_DATA = {}
#################################### ####################################
# Config helpers # Config helpers
#################################### ####################################
# Function to run the alembic migrations
def run_migrations():
print("Running migrations")
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
run_migrations()
class Config(Base):
__tablename__ = "config"
id = Column(Integer, primary_key=True)
data = Column(JSON, nullable=False)
version = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, nullable=False, server_default=func.now())
updated_at = Column(DateTime, nullable=True, onupdate=func.now())
def load_initial_config():
with open(f"{DATA_DIR}/config.json", "r") as file:
return json.load(file)
def save_to_db(data):
with get_db() as db:
existing_config = db.query(Config).first()
if not existing_config:
new_config = Config(data=data, version=0)
db.add(new_config)
else:
existing_config.data = data
db.commit()
# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
data = load_initial_config()
save_to_db(data)
os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
def save_config(): def save_config():
try: try:
with open(f"{DATA_DIR}/config.json", "w") as f: with open(f"{DATA_DIR}/config.json", "w") as f:
@ -244,9 +144,9 @@ class PersistentConfig(Generic[T]):
self.env_name = env_name self.env_name = env_name
self.config_path = config_path self.config_path = config_path
self.env_value = env_value self.env_value = env_value
self.config_value = get_config_value(config_path) self.config_value = self.load_latest_config_value(config_path)
if self.config_value is not None: if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json") log.info(f"'{env_name}' loaded from the latest database entry")
self.value = self.config_value self.value = self.config_value
else: else:
self.value = env_value self.value = env_value
@ -254,33 +154,44 @@ class PersistentConfig(Generic[T]):
def __str__(self): def __str__(self):
return str(self.value) return str(self.value)
@property def load_latest_config_value(self, config_path: str):
def __dict__(self): with get_db() as db:
raise TypeError( config_entry = db.query(Config).order_by(Config.id.desc()).first()
"PersistentConfig object cannot be converted to dict, use config_get or .value instead." if config_entry:
) try:
path_parts = config_path.split(".")
def __getattribute__(self, item): config_value = config_entry.data
if item == "__dict__": for key in path_parts:
raise TypeError( config_value = config_value[key]
"PersistentConfig object cannot be converted to dict, use config_get or .value instead." return config_value
) except KeyError:
return super().__getattribute__(item) return None
def save(self): def save(self):
# Don't save if the value is the same as the env value and the config value if self.env_value == self.value and self.config_value == self.value:
if self.env_value == self.value: return
if self.config_value == self.value: log.info(f"Saving '{self.env_name}' to the database")
return
log.info(f"Saving '{self.env_name}' to config.json")
path_parts = self.config_path.split(".") path_parts = self.config_path.split(".")
config = CONFIG_DATA with get_db() as db:
for key in path_parts[:-1]: existing_config = db.query(Config).first()
if key not in config: if existing_config:
config[key] = {} config = existing_config.data
config = config[key] for key in path_parts[:-1]:
config[path_parts[-1]] = self.value if key not in config:
save_config() config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
existing_config.version += 1
else: # This case should not actually occur as there should always be at least one entry
new_data = {}
config = new_data
for key in path_parts[:-1]:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
new_config = Config(data=new_data, version=1)
db.add(new_config)
db.commit()
self.config_value = self.value self.config_value = self.value
@ -305,11 +216,6 @@ class AppConfig:
# WEBUI_AUTH (Required for security) # WEBUI_AUTH (Required for security)
#################################### ####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
JWT_EXPIRES_IN = PersistentConfig( JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
) )
@ -999,30 +905,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
) )
####################################
# WEBUI_SECRET_KEY
####################################
WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
#################################### ####################################
# RAG document content extraction # RAG document content extraction
#################################### ####################################
@ -1553,14 +1435,3 @@ AUDIO_TTS_VOICE = PersistentConfig(
"audio.tts.voice", "audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice
) )
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")

View File

@ -1,36 +0,0 @@
{
"version": 0,
"ui": {
"default_locale": "",
"prompt_suggestions": [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
},
{
"title": ["Give me ideas", "for what to do with my kids' art"],
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
},
{
"title": ["Tell me a fun fact", "about the Roman Empire"],
"content": "Tell me a random fun fact about the Roman Empire"
},
{
"title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
},
{
"title": ["Explain options trading", "if I'm familiar with buying and selling stocks"],
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks."
},
{
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?"
},
{
"title": ["Grammar check", "rewrite it for better readability "],
"content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning."
}
]
}
}

252
backend/env.py Normal file
View File

@ -0,0 +1,252 @@
from pathlib import Path
import os
import logging
import sys
import json
import importlib.metadata
import pkgutil
from urllib.parse import urlparse
from datetime import datetime
import markdown
from bs4 import BeautifulSoup
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
BACKEND_DIR = Path(__file__).parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR)
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError:
print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except Exception:
try:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
except importlib.metadata.PackageNotFoundError:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# WEBUI_BUILD_HASH
####################################
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except Exception:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except Exception:
CONFIG_DATA = {}
####################################
# Database
####################################
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
####################################
# WEBUI_SECRET_KEY
####################################
WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)

View File

@ -1,7 +1,6 @@
import base64 import base64
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
import json import json
@ -87,6 +86,7 @@ from utils.misc import (
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
run_migrations,
WEBUI_NAME, WEBUI_NAME,
WEBUI_URL, WEBUI_URL,
WEBUI_AUTH, WEBUI_AUTH,
@ -165,17 +165,6 @@ https://github.com/open-webui/open-webui
) )
def run_migrations():
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
run_migrations() run_migrations()

View File

@ -18,7 +18,7 @@ from apps.webui.models.users import User
from apps.webui.models.files import File from apps.webui.models.files import File
from apps.webui.models.functions import Function from apps.webui.models.functions import Function
from config import DATABASE_URL from env import DATABASE_URL
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@ -0,0 +1,43 @@
"""Add config table
Revision ID: ca81bd47c050
Revises: 7e5b5dc7342b
Create Date: 2024-08-25 15:26:35.241684
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "ca81bd47c050"
down_revision: Union[str, None] = "7e5b5dc7342b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
op.create_table(
"config",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column(
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=True,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
)
def downgrade():
op.drop_table("config")

View File

@ -10,12 +10,12 @@ from datetime import datetime, timedelta, UTC
import jwt import jwt
import uuid import uuid
import logging import logging
import config from env import WEBUI_SECRET_KEY
logging.getLogger("passlib").setLevel(logging.ERROR) logging.getLogger("passlib").setLevel(logging.ERROR)
SESSION_SECRET = config.WEBUI_SECRET_KEY SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256" ALGORITHM = "HS256"
############## ##############