feat: config.json db migration

This commit is contained in:
Timothy J. Baek 2024-08-25 16:52:36 +02:00
parent 072945c40b
commit 58cf1be20c
16 changed files with 432 additions and 322 deletions

View File

@ -3,18 +3,19 @@ import logging
import json
from contextlib import contextmanager
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from typing import Optional, Any
from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.sql.type_api import _T
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.sql.type_api import _T
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
@ -42,15 +43,6 @@ class JSONField(types.TypeDecorator):
return json.loads(value)
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
# Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration
def handle_peewee_migration(DATABASE_URL):
@ -94,7 +86,6 @@ Base = declarative_base()
Session = scoped_session(SessionLocal)
# Dependency
def get_session():
db = SessionLocal()
try:

View File

@ -6,7 +6,7 @@ import logging
from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])

View File

@ -4,12 +4,12 @@ import uuid
import logging
from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.webui.models.users import UserModel, Users
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -9,7 +9,7 @@ from apps.webui.internal.db import Base, get_db
import json
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -9,7 +9,7 @@ from apps.webui.internal.db import JSONField, Base, get_db
import json
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -12,7 +12,7 @@ import json
import copy
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -6,7 +6,7 @@ from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
import time

View File

@ -10,7 +10,7 @@ from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -11,7 +11,7 @@ import json
import copy
from config import SRC_LOG_LEVELS
from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -1,13 +1,17 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
from contextlib import contextmanager
import os
import sys
import logging
import importlib.metadata
import pkgutil
from urllib.parse import urlparse
from datetime import datetime
import chromadb
from chromadb import Settings
from bs4 import BeautifulSoup
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
@ -16,68 +20,39 @@ from pathlib import Path
import json
import yaml
import markdown
import requests
import shutil
from apps.webui.internal.db import Base, get_db
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
BACKEND_DIR = Path(__file__).parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR)
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError:
print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
from env import (
ENV,
VERSION,
SAFE_MODE,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
BASE_DIR,
DATA_DIR,
BACKEND_DIR,
FRONTEND_BUILD_DIR,
WEBUI_NAME,
WEBUI_URL,
WEBUI_FAVICON_URL,
WEBUI_BUILD_HASH,
CONFIG_DATA,
DATABASE_URL,
CHANGELOG,
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
log,
)
class EndpointFilter(logging.Filter):
@ -88,135 +63,60 @@ class EndpointFilter(logging.Filter):
# Filter out /endpoint
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except Exception:
try:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
except importlib.metadata.PackageNotFoundError:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# WEBUI_BUILD_HASH
####################################
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except Exception:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except Exception:
CONFIG_DATA = {}
####################################
# Config helpers
####################################
# Function to run the alembic migrations
def run_migrations():
print("Running migrations")
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
run_migrations()
class Config(Base):
__tablename__ = "config"
id = Column(Integer, primary_key=True)
data = Column(JSON, nullable=False)
version = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, nullable=False, server_default=func.now())
updated_at = Column(DateTime, nullable=True, onupdate=func.now())
def load_initial_config():
with open(f"{DATA_DIR}/config.json", "r") as file:
return json.load(file)
def save_to_db(data):
with get_db() as db:
existing_config = db.query(Config).first()
if not existing_config:
new_config = Config(data=data, version=0)
db.add(new_config)
else:
existing_config.data = data
db.commit()
# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
data = load_initial_config()
save_to_db(data)
os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
@ -244,9 +144,9 @@ class PersistentConfig(Generic[T]):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
self.config_value = self.load_latest_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
log.info(f"'{env_name}' loaded from the latest database entry")
self.value = self.config_value
else:
self.value = env_value
@ -254,33 +154,44 @@ class PersistentConfig(Generic[T]):
def __str__(self):
return str(self.value)
@property
def __dict__(self):
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def __getattribute__(self, item):
if item == "__dict__":
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return super().__getattribute__(item)
def load_latest_config_value(self, config_path: str):
with get_db() as db:
config_entry = db.query(Config).order_by(Config.id.desc()).first()
if config_entry:
try:
path_parts = config_path.split(".")
config_value = config_entry.data
for key in path_parts:
config_value = config_value[key]
return config_value
except KeyError:
return None
def save(self):
# Don't save if the value is the same as the env value and the config value
if self.env_value == self.value:
if self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to config.json")
if self.env_value == self.value and self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to the database")
path_parts = self.config_path.split(".")
config = CONFIG_DATA
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
save_config()
with get_db() as db:
existing_config = db.query(Config).first()
if existing_config:
config = existing_config.data
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
existing_config.version += 1
else: # This case should not actually occur as there should always be at least one entry
new_data = {}
config = new_data
for key in path_parts[:-1]:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
new_config = Config(data=new_data, version=1)
db.add(new_config)
db.commit()
self.config_value = self.value
@ -305,11 +216,6 @@ class AppConfig:
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
@ -999,30 +905,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)
####################################
# WEBUI_SECRET_KEY
####################################
WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
@ -1553,14 +1435,3 @@ AUDIO_TTS_VOICE = PersistentConfig(
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice
)
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")

View File

@ -1,36 +0,0 @@
{
"version": 0,
"ui": {
"default_locale": "",
"prompt_suggestions": [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
},
{
"title": ["Give me ideas", "for what to do with my kids' art"],
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
},
{
"title": ["Tell me a fun fact", "about the Roman Empire"],
"content": "Tell me a random fun fact about the Roman Empire"
},
{
"title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
},
{
"title": ["Explain options trading", "if I'm familiar with buying and selling stocks"],
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks."
},
{
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?"
},
{
"title": ["Grammar check", "rewrite it for better readability "],
"content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning."
}
]
}
}

252
backend/env.py Normal file
View File

@ -0,0 +1,252 @@
from pathlib import Path
import os
import logging
import sys
import json
import importlib.metadata
import pkgutil
from urllib.parse import urlparse
from datetime import datetime
import markdown
from bs4 import BeautifulSoup
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
BACKEND_DIR = Path(__file__).parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR)
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError:
print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except Exception:
try:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
except importlib.metadata.PackageNotFoundError:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# WEBUI_BUILD_HASH
####################################
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except Exception:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except Exception:
CONFIG_DATA = {}
####################################
# Database
####################################
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
####################################
# WEBUI_SECRET_KEY
####################################
WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)

View File

@ -1,7 +1,6 @@
import base64
import uuid
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
import json
@ -87,6 +86,7 @@ from utils.misc import (
from apps.rag.utils import get_rag_context, rag_template
from config import (
run_migrations,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
@ -165,17 +165,6 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()

View File

@ -18,7 +18,7 @@ from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
from env import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View File

@ -0,0 +1,43 @@
"""Add config table
Revision ID: ca81bd47c050
Revises: 7e5b5dc7342b
Create Date: 2024-08-25 15:26:35.241684
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "ca81bd47c050"
down_revision: Union[str, None] = "7e5b5dc7342b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
op.create_table(
"config",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column(
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=True,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
)
def downgrade():
op.drop_table("config")

View File

@ -10,12 +10,12 @@ from datetime import datetime, timedelta, UTC
import jwt
import uuid
import logging
import config
from env import WEBUI_SECRET_KEY
logging.getLogger("passlib").setLevel(logging.ERROR)
SESSION_SECRET = config.WEBUI_SECRET_KEY
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"
##############