mirror of
https://github.com/open-webui/open-webui
synced 2024-12-24 12:52:12 +00:00
1498 lines
41 KiB
Python
1498 lines
41 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Generic, Optional, TypeVar
|
|
from urllib.parse import urlparse
|
|
|
|
import chromadb
|
|
import requests
|
|
import yaml
|
|
from open_webui.apps.webui.internal.db import Base, get_db
|
|
from open_webui.env import (
|
|
OPEN_WEBUI_DIR,
|
|
DATA_DIR,
|
|
ENV,
|
|
FRONTEND_BUILD_DIR,
|
|
WEBUI_AUTH,
|
|
WEBUI_FAVICON_URL,
|
|
WEBUI_NAME,
|
|
log,
|
|
)
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
|
|
|
|
|
class EndpointFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``/health``."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` for records containing ``/health``, ``True`` otherwise."""
        message = record.getMessage()
        return "/health" not in message
|
|
|
|
|
|
# Filter out /health requests from the uvicorn access log
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
|
|
|
|
####################################
|
|
# Config helpers
|
|
####################################
|
|
|
|
|
|
# Function to run the alembic migrations
|
|
def run_migrations():
    """Upgrade the database schema to the latest Alembic revision.

    Builds an Alembic config from ``OPEN_WEBUI_DIR / "alembic.ini"``, points
    its ``script_location`` at ``OPEN_WEBUI_DIR / "migrations"`` and upgrades
    to ``head``. Any failure is logged with its traceback and swallowed, so
    this function never raises.
    """
    log.info("Running migrations")
    try:
        # NOTE: this `Config` is alembic's and is local to this function; the
        # module defines its own `Config` ORM model further down.
        from alembic import command
        from alembic.config import Config

        alembic_cfg = Config(OPEN_WEBUI_DIR / "alembic.ini")

        # Set the script location dynamically
        migrations_path = OPEN_WEBUI_DIR / "migrations"
        alembic_cfg.set_main_option("script_location", str(migrations_path))

        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        # Best effort: keep start-up going, but record the full traceback
        # instead of printing only the message.
        log.exception(f"Error running migrations: {e}")


run_migrations()
|
|
|
|
|
|
class Config(Base):
    """ORM model for the ``config`` table; each row holds one JSON config document."""

    __tablename__ = "config"

    id = Column(Integer, primary_key=True)
    # The configuration payload (required).
    data = Column(JSON, nullable=False)
    # Required integer version; defaults to 0 on insert.
    version = Column(Integer, nullable=False, default=0)
    # Filled in by the database (func.now()) when the row is inserted.
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    # Nullable; set to func.now() on UPDATE statements.
    updated_at = Column(DateTime, nullable=True, onupdate=func.now())
|
|
|
|
|
|
def load_json_config():
    """Read and parse ``{DATA_DIR}/config.json``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(f"{DATA_DIR}/config.json", "r", encoding="utf-8") as file:
        return json.load(file)
|
|
|
|
|
|
def save_to_db(data):
    """Store *data* as the configuration row and commit.

    The first ``Config`` row is looked up. When none exists a new row is
    created with ``version=0``; otherwise the existing row's ``data`` is
    replaced and its ``updated_at`` is set to the current local time.
    """
    with get_db() as db:
        row = db.query(Config).first()
        if row:
            row.data = data
            row.updated_at = datetime.now()
        else:
            row = Config(data=data, version=0)
        db.add(row)
        db.commit()
|
|
|
|
|
|
def reset_config():
    """Delete every row of the ``config`` table and commit the change."""
    with get_db() as session:
        session.query(Config).delete()
        session.commit()
|
|
|
|
|
|
# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
    data = load_json_config()
    save_to_db(data)
    # Rename the legacy file; the existence check above will then be false on
    # the next import, so the migration is not repeated.
    os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
|
|
|
|
# Fallback configuration document, returned by get_config() when the config
# table has no rows.
DEFAULT_CONFIG = {
    "version": 0,
    "ui": {
        "default_locale": "",
        # Each suggestion has a two-part "title" and the "content" prompt text.
        "prompt_suggestions": [
            {
                "title": [
                    "Help me study",
                    "vocabulary for a college entrance exam",
                ],
                "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
            },
            {
                "title": [
                    "Give me ideas",
                    "for what to do with my kids' art",
                ],
                "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
            },
            {
                "title": ["Tell me a fun fact", "about the Roman Empire"],
                "content": "Tell me a random fun fact about the Roman Empire",
            },
            {
                "title": [
                    "Show me a code snippet",
                    "of a website's sticky header",
                ],
                "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
            },
            {
                "title": [
                    "Explain options trading",
                    "if I'm familiar with buying and selling stocks",
                ],
                "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
            },
            {
                "title": ["Overcome procrastination", "give me tips"],
                "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
            },
            {
                "title": [
                    "Grammar check",
                    "rewrite it for better readability ",
                ],
                "content": 'Check the following sentence for grammar and clarity: "[sentence]". Rewrite it for better readability while maintaining its original meaning.',
            },
        ],
    },
}
|
|
|
|
|
|
def get_config():
    """Return the stored configuration document.

    The ``Config`` row with the highest ``id`` is used; when the table is
    empty ``DEFAULT_CONFIG`` is returned instead.
    """
    with get_db() as db:
        latest = db.query(Config).order_by(Config.id.desc()).first()
        if latest:
            return latest.data
        return DEFAULT_CONFIG


# Module-level configuration document, loaded once at import time.
CONFIG_DATA = get_config()
|
|
|
|
|
|
def get_config_value(config_path: str):
    """Look up a dotted path (e.g. ``"ui.default_locale"``) in ``CONFIG_DATA``.

    Args:
        config_path: Keys separated by ``.``, walked from the root of
            ``CONFIG_DATA``.

    Returns:
        The value stored at the path, or ``None`` when a key along the path
        is missing or an intermediate value is not a dict.
    """
    cur_config = CONFIG_DATA
    for key in config_path.split("."):
        # Only dicts can be descended into. Without this check a str/list
        # value would get a substring/element membership test followed by a
        # TypeError on indexing, and a number would raise on `in` directly.
        if not isinstance(cur_config, dict) or key not in cur_config:
            return None
        cur_config = cur_config[key]
    return cur_config
|
|
|
|
|
|
# Every PersistentConfig appends itself here in its constructor.
PERSISTENT_CONFIG_REGISTRY = []


def save_config(config):
    """Persist *config*, make it the active ``CONFIG_DATA`` and refresh all entries.

    Returns:
        ``True`` on success. ``False`` if saving or refreshing raised; the
        exception is logged and not propagated.
    """
    global CONFIG_DATA
    try:
        save_to_db(config)
        CONFIG_DATA = config

        # Trigger updates on all registered PersistentConfig entries
        for entry in PERSISTENT_CONFIG_REGISTRY:
            entry.update()
    except Exception as err:
        log.exception(err)
        return False
    else:
        return True
|
|
|
|
|
|
T = TypeVar("T")


class PersistentConfig(Generic[T]):
    """A setting backed by the stored configuration, with an env-derived fallback.

    On construction the value found at ``config_path`` in the stored
    configuration wins when it is not ``None``; otherwise ``env_value`` is
    used. Every instance registers itself in ``PERSISTENT_CONFIG_REGISTRY``
    so that ``save_config`` can refresh it.

    Attributes:
        env_name: Name used in log messages to identify the setting.
        config_path: Dotted path of the setting inside the config document.
        env_value: Fallback value supplied by the caller.
        config_value: Value last read from / written to the config document.
        value: The effective value.
    """

    def __init__(self, env_name: str, config_path: str, env_value: T):
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value
        self.config_value = get_config_value(config_path)
        if self.config_value is not None:
            log.info(f"'{env_name}' loaded from the latest database entry")
            self.value = self.config_value
        else:
            self.value = env_value

        PERSISTENT_CONFIG_REGISTRY.append(self)

    def __str__(self):
        """Return ``str()`` of the effective value."""
        return str(self.value)

    @property
    def __dict__(self):
        # Guard against accidentally serialising the wrapper instead of its value.
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def __getattribute__(self, item):
        # Instance-level `obj.__dict__` lookups go through here first, so this
        # is the check that actually fires for instances.
        if item == "__dict__":
            raise TypeError(
                "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
            )
        return super().__getattribute__(item)

    def update(self):
        """Reload the value from ``CONFIG_DATA``; a missing/``None`` entry is ignored."""
        new_value = get_config_value(self.config_path)
        if new_value is not None:
            self.value = new_value
            # Do not log the value itself: several settings hold API keys or
            # OAuth client secrets, which must not end up in log files.
            log.info(f"Updated {self.env_name} to new value from the database")

    def save(self):
        """Write the current value into ``CONFIG_DATA`` at ``config_path`` and persist it.

        Missing intermediate dicts along the path are created. ``CONFIG_DATA``
        is modified in place before being handed to ``save_to_db``.
        """
        log.info(f"Saving '{self.env_name}' to the database")
        path_parts = self.config_path.split(".")
        sub_config = CONFIG_DATA
        for key in path_parts[:-1]:
            if key not in sub_config:
                sub_config[key] = {}
            sub_config = sub_config[key]
        sub_config[path_parts[-1]] = self.value
        save_to_db(CONFIG_DATA)
        self.config_value = self.value
|
|
|
|
|
|
class AppConfig:
    """Attribute-style access to a set of ``PersistentConfig`` entries.

    Assigning a ``PersistentConfig`` registers it under the attribute name.
    Assigning any other value to an already registered name updates that
    entry's value and saves it. Reading an attribute returns the entry's
    current ``value``.
    """

    _state: "dict[str, PersistentConfig]"

    def __init__(self):
        # Go through object.__setattr__: our own __setattr__ only handles
        # config entries and their values.
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        """Register a ``PersistentConfig`` or update-and-save a registered one.

        Raises:
            KeyError: If a plain value is assigned to a name that has not been
                registered with a ``PersistentConfig`` first.
        """
        if isinstance(value, PersistentConfig):
            self._state[key] = value
        else:
            self._state[key].value = value
            self._state[key].save()

    def __getattr__(self, key):
        """Return the value of the entry registered under *key*.

        Raises:
            AttributeError: If no entry is registered under *key*. Raising
                ``AttributeError`` (not ``KeyError``) is what ``hasattr`` and
                ``getattr(obj, name, default)`` rely on.
        """
        # __getattr__ only runs when normal lookup failed. If `_state` itself
        # is missing (instance created without __init__), bail out instead of
        # recursing through `self._state` forever.
        if key == "_state":
            raise AttributeError(key)
        try:
            return self._state[key].value
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None
|
|
|
|
|
|
####################################
|
|
# WEBUI_AUTH (Required for security)
|
|
####################################
|
|
|
|
# Stored under "auth.jwt_expiry"; falls back to the JWT_EXPIRES_IN env var,
# or the string "-1" when that is unset.
JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
|
|
|
|
####################################
|
|
# OAuth config
|
|
####################################
|
|
|
|
# Boolean flags: the env fallback is true only when the variable equals
# "true" (case-insensitive); both default to False.
ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

# Provider name -> provider settings; filled in by load_oauth_providers().
OAUTH_PROVIDERS = {}
|
|
|
|
# Google OAuth settings, stored under "oauth.google.*". All env fallbacks
# default to "" except the scope.
GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)
|
|
|
|
# Microsoft OAuth settings, stored under "oauth.microsoft.*". All env
# fallbacks default to "" except the scope.
MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

# Interpolated into the Microsoft metadata URL by load_oauth_providers().
MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)
|
|
|
|
# Generic OIDC provider settings, stored under "oauth.oidc.*".
OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

# Used as the provider's "server_metadata_url" by load_oauth_providers().
OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

# Exposed as the provider's "name" by load_oauth_providers(); defaults to "SSO".
OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
|
|
|
|
# Stored under "oauth.oidc.avatar_claim"; the env fallback defaults to
# "picture". The first argument is the name PersistentConfig uses in its log
# messages, so it must identify this setting (it previously read
# "OAUTH_USERNAME_CLAIM", a copy-paste slip).
OAUTH_PICTURE_CLAIM = PersistentConfig(
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
|
|
|
|
# Stored under "oauth.oidc.email_claim"; the env fallback defaults to "email".
OAUTH_EMAIL_CLAIM = PersistentConfig(
    "OAUTH_EMAIL_CLAIM",
    "oauth.oidc.email_claim",
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
|
|
|
|
|
|
def load_oauth_providers():
    """Rebuild ``OAUTH_PROVIDERS`` in place from the current OAuth settings.

    The dict is emptied first. A provider entry is added only when all of its
    required settings are truthy: client id and secret for ``google``; client
    id, secret and tenant id for ``microsoft``; client id, secret and provider
    URL for ``oidc``.
    """
    OAUTH_PROVIDERS.clear()

    google_configured = GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value
    if google_configured:
        google = {
            "client_id": GOOGLE_CLIENT_ID.value,
            "client_secret": GOOGLE_CLIENT_SECRET.value,
            "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
            "scope": GOOGLE_OAUTH_SCOPE.value,
            "redirect_uri": GOOGLE_REDIRECT_URI.value,
        }
        OAUTH_PROVIDERS["google"] = google

    microsoft_configured = (
        MICROSOFT_CLIENT_ID.value
        and MICROSOFT_CLIENT_SECRET.value
        and MICROSOFT_CLIENT_TENANT_ID.value
    )
    if microsoft_configured:
        # The metadata document is tenant specific.
        microsoft = {
            "client_id": MICROSOFT_CLIENT_ID.value,
            "client_secret": MICROSOFT_CLIENT_SECRET.value,
            "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
            "scope": MICROSOFT_OAUTH_SCOPE.value,
            "redirect_uri": MICROSOFT_REDIRECT_URI.value,
        }
        OAUTH_PROVIDERS["microsoft"] = microsoft

    oidc_configured = (
        OAUTH_CLIENT_ID.value
        and OAUTH_CLIENT_SECRET.value
        and OPENID_PROVIDER_URL.value
    )
    if oidc_configured:
        oidc = {
            "client_id": OAUTH_CLIENT_ID.value,
            "client_secret": OAUTH_CLIENT_SECRET.value,
            "server_metadata_url": OPENID_PROVIDER_URL.value,
            "scope": OAUTH_SCOPES.value,
            "name": OAUTH_PROVIDER_NAME.value,
            "redirect_uri": OPENID_REDIRECT_URI.value,
        }
        OAUTH_PROVIDERS["oidc"] = oidc


load_oauth_providers()
|
|
|
|
####################################
|
|
# Static DIR
|
|
####################################
|
|
|
|
# Absolute static directory: the STATIC_DIR env var, or OPEN_WEBUI_DIR/static.
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()

frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"

# Copy the frontend build's favicon into STATIC_DIR. A failed copy is logged
# and otherwise ignored; a missing source file only produces a warning.
# NOTE: these messages go through the root `logging` module, not `log`.
if frontend_favicon.exists():
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"

# Same handling for the splash image.
if frontend_splash.exists():
    try:
        shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend splash not found at {frontend_splash}")
|
|
|
|
|
|
####################################
|
|
# CUSTOM_NAME
|
|
####################################
|
|
|
|
CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")

# When CUSTOM_NAME is set, fetch branding (name, logo, splash) from
# api.openwebui.com at import time. This is best effort: any failure is
# logged and start-up continues with the values set so far.
if CUSTOM_NAME:
    try:
        # Every request carries a timeout; without one a stalled connection
        # would block application start-up indefinitely.
        r = requests.get(
            f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}", timeout=10
        )
        data = r.json()
        if r.ok:
            if "logo" in data:
                # A logo path starting with "/" is relative to the API host.
                WEBUI_FAVICON_URL = url = (
                    f"https://api.openwebui.com{data['logo']}"
                    if data["logo"][0] == "/"
                    else data["logo"]
                )

                r = requests.get(url, stream=True, timeout=10)
                if r.status_code == 200:
                    with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
                        # Have urllib3 undo any content-encoding while streaming.
                        r.raw.decode_content = True
                        shutil.copyfileobj(r.raw, f)

            if "splash" in data:
                url = (
                    f"https://api.openwebui.com{data['splash']}"
                    if data["splash"][0] == "/"
                    else data["splash"]
                )

                r = requests.get(url, stream=True, timeout=10)
                if r.status_code == 200:
                    with open(f"{STATIC_DIR}/splash.png", "wb") as f:
                        r.raw.decode_content = True
                        shutil.copyfileobj(r.raw, f)

            WEBUI_NAME = data["name"]
    except Exception as e:
        log.exception(e)
|
|
|
|
|
|
####################################
|
|
# File Upload DIR
|
|
####################################
|
|
|
|
# Each directory below is created at import time, including missing parents;
# an already existing directory is not an error.
UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Cache DIR
####################################

CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

####################################
# Tools DIR
####################################

# Overridable through the TOOLS_DIR env var.
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Functions DIR
####################################

# Overridable through the FUNCTIONS_DIR env var.
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
|
|
|
####################################
|
|
# OLLAMA_BASE_URL
|
|
####################################
|
|
|
|
|
|
ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")

K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

# When OLLAMA_BASE_URL is not set, derive it from OLLAMA_API_BASE_URL by
# dropping a trailing "/api".
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = (
        OLLAMA_API_BASE_URL[:-4]
        if OLLAMA_API_BASE_URL.endswith("/api")
        else OLLAMA_API_BASE_URL
    )

# In "prod", the "/ollama" placeholder is replaced by a concrete host; with
# K8S_FLAG set, the in-cluster service URL is used instead.
if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # if you use all-in-one docker container (Open WebUI + Ollama)
            # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"


# OLLAMA_BASE_URLS is a ";"-separated list; it falls back to the single
# OLLAMA_BASE_URL when unset.
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
|
|
|
|
####################################
|
|
# OPENAI_API
|
|
####################################
|
|
|
|
|
|
ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)


OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")


if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

# OPENAI_API_KEYS is a ";"-separated list; it falls back to the single
# OPENAI_API_KEY when unset.
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

# Same scheme for the base URLs; empty list items become the official endpoint.
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

OPENAI_API_BASE_URLS = [
    url.strip() if url != "" else "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

# From here on OPENAI_API_KEY is the key at the same list position as the
# official endpoint URL, or "" when that URL is not configured (any lookup
# error is swallowed).
OPENAI_API_KEY = ""

try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except Exception:
    pass

# NOTE: this overwrites whatever the OPENAI_API_BASE_URL env var supplied, so
# later readers of this name always see the official endpoint.
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
|
|
|
####################################
|
|
# WEBUI
|
|
####################################
|
|
|
|
# The env fallback is forced to False when WEBUI_AUTH is falsy.
ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
    (
        False
        if not WEBUI_AUTH
        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
    ),
)

# NOTE(review): the config path uses upper case ("ui.ENABLE_LOGIN_FORM"),
# unlike its lower-case siblings. It is the persisted key, so confirm a
# migration plan before renaming it.
ENABLE_LOGIN_FORM = PersistentConfig(
    "ENABLE_LOGIN_FORM",
    "ui.ENABLE_LOGIN_FORM",
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
    os.environ.get("DEFAULT_LOCALE", ""),
)

# The env fallback is None when DEFAULT_MODELS is unset.
DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

# Fallback suggestions used when "ui.prompt_suggestions" is not stored; each
# item has a two-part "title" and the "content" prompt text.
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    "DEFAULT_PROMPT_SUGGESTIONS",
    "ui.prompt_suggestions",
    [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

# The env fallback defaults to "pending".
DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)
|
|
|
|
# Plain env-only booleans (default True) that seed USER_PERMISSIONS below.
USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_EDITING = (
    os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_TEMPORARY = (
    os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
)

USER_PERMISSIONS = PersistentConfig(
    "USER_PERMISSIONS",
    "ui.user_permissions",
    {
        "chat": {
            "deletion": USER_PERMISSIONS_CHAT_DELETION,
            "editing": USER_PERMISSIONS_CHAT_EDITING,
            "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
        }
    },
)

ENABLE_MODEL_FILTER = PersistentConfig(
    "ENABLE_MODEL_FILTER",
    "model_filter.enable",
    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
# ";"-separated list of model names. NOTE: when the env var is unset,
# "".split(";") yields [""], so the fallback is a list holding one empty
# string rather than an empty list.
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
    "MODEL_FILTER_LIST",
    "model_filter.list",
    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
)

WEBHOOK_URL = PersistentConfig(
    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)

# Env-only flags (not PersistentConfig), default True.
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_ADMIN_CHAT_ACCESS = (
    os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
)

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    "ENABLE_COMMUNITY_SHARING",
    "ui.enable_community_sharing",
    os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)

ENABLE_MESSAGE_RATING = PersistentConfig(
    "ENABLE_MESSAGE_RATING",
    "ui.enable_message_rating",
    os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
)
|
|
|
|
|
|
def validate_cors_origins(origins):
    """Validate every entry of *origins* except the ``"*"`` wildcard.

    Raises:
        ValueError: Propagated from ``validate_cors_origin`` for the first
            invalid entry.
    """
    for candidate in (entry for entry in origins if entry != "*"):
        validate_cors_origin(candidate)


def validate_cors_origin(origin):
    """Check that *origin* is an ``http``/``https`` URL with a network location.

    Raises:
        ValueError: If the scheme is not ``http`` or ``https``, or if the URL
            has no netloc (host[:port]) part.
    """
    parts = urlparse(origin)

    # Check if the scheme is either http or https
    scheme_allowed = parts.scheme in ["http", "https"]
    if not scheme_allowed:
        raise ValueError(
            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
        )

    # Ensure that the netloc (domain + port) is present, indicating it's a valid URL
    if parts.netloc:
        return
    raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
|
|
|
|
|
|
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
# The value is a ";"-separated list and defaults to the "*" wildcard.
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

if "*" in CORS_ALLOW_ORIGIN:
    log.warning(
        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
    )

# Raises ValueError at import time when a non-wildcard entry is not a valid
# http(s) origin.
validate_cors_origins(CORS_ALLOW_ORIGIN)
|
|
|
|
|
|
class BannerModel(BaseModel):
    """Schema of one entry of the ``WEBUI_BANNERS`` JSON list."""

    id: str
    type: str
    # The only optional field; defaults to None.
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int
|
|
|
|
|
|
# Parse the WEBUI_BANNERS env var (a JSON list, default "[]") into
# BannerModel objects. Any parse or validation error is printed and results in
# an empty banner list.
try:
    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
    print(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)


SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

# The env fallback is None when ADMIN_EMAIL is unset.
ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)
|
|
|
|
|
|
####################################
|
|
# TASKS
|
|
####################################
|
|
|
|
|
|
# Task settings, stored under "task.*". Every string setting falls back to
# "" when its env var is unset.
TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)

ENABLE_SEARCH_QUERY = PersistentConfig(
    "ENABLE_SEARCH_QUERY",
    "task.search.enable",
    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
)


SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)


TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.environ.get("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""),
)
|
|
|
|
|
|
####################################
|
|
# Vector Database
|
|
####################################
|
|
|
|
# Name of the vector database backend; defaults to "chroma".
VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")

# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
# Raises ValueError at import time when the env var is not an integer.
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
|
|
# Comma-separated list of header=value pairs, parsed into a dict; None when
# the env var is unset or empty.
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
    # Split each pair on the first "=" only, so values that themselves contain
    # "=" (e.g. base64 padding in an Authorization header) stay intact instead
    # of making dict() raise.
    CHROMA_HTTP_HEADERS = dict(
        [pair.split("=", 1) for pair in CHROMA_HTTP_HEADERS.split(",")]
    )
else:
    CHROMA_HTTP_HEADERS = None
|
|
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you don't use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
# NOTE(review): the comment above appears to describe the embedding model setting rather than anything in this section; confirm where it belongs.

# Milvus

# Defaults to a database file under DATA_DIR/vector_db.
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")

# Qdrant
# None when QDRANT_URI is unset.
QDRANT_URI = os.environ.get("QDRANT_URI", None)
|
|
|
|
####################################
|
|
# Information Retrieval (RAG)
|
|
####################################
|
|
|
|
# RAG Content Extraction
|
|
# NOTE(review): the config path uses upper case
# ("rag.CONTENT_EXTRACTION_ENGINE"), unlike its lower-case siblings. It is the
# persisted key, so confirm a migration plan before renaming it.
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
)

# Numeric settings are converted eagerly; a malformed env value raises
# ValueError at import time.
RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)

# Default "" means the fallback is False unless the env var equals "true".
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)

# The two file limits fall back to None when their env var is unset or empty.
RAG_FILE_MAX_COUNT = PersistentConfig(
    "RAG_FILE_MAX_COUNT",
    "rag.file.max_count",
    (
        int(os.environ.get("RAG_FILE_MAX_COUNT"))
        if os.environ.get("RAG_FILE_MAX_COUNT")
        else None
    ),
)

RAG_FILE_MAX_SIZE = PersistentConfig(
    "RAG_FILE_MAX_SIZE",
    "rag.file.max_size",
    (
        int(os.environ.get("RAG_FILE_MAX_SIZE"))
        if os.environ.get("RAG_FILE_MAX_SIZE")
        else None
    ),
)

ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)

RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)

PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)

RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
|
|
|
|
# Env-only flags (not PersistentConfig); False unless the env var equals "true".
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

# RAG_EMBEDDING_BATCH_SIZE takes precedence; otherwise
# RAG_EMBEDDING_OPENAI_BATCH_SIZE is used, with "1" as the final default.
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_BATCH_SIZE",
    "rag.embedding_batch_size",
    int(
        os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
        or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
    ),
)

RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")

RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)


RAG_TEXT_SPLITTER = PersistentConfig(
    "RAG_TEXT_SPLITTER",
    "rag.text_splitter",
    os.environ.get("RAG_TEXT_SPLITTER", ""),
)


# Defaults to a "tiktoken" directory inside CACHE_DIR.
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(
    "TIKTOKEN_ENCODING_NAME",
    "rag.tiktoken_encoding_name",
    os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
)


# Integer settings; defaults are 1000 and 100.
CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP",
    "rag.chunk_overlap",
    int(os.environ.get("CHUNK_OVERLAP", "100")),
)
|
|
|
|
# Default prompt template; it contains the {{CONTEXT}} and {{QUERY}}
# placeholders.
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.

<context>
{{CONTEXT}}
</context>

<rules>
- If you don't know, just say so.
- If you are not sure, ask for clarification.
- Answer in the same language as the user query.
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
</rules>

<user_query>
{{QUERY}}
</user_query>
"""

# Falls back to the RAG_TEMPLATE env var, then to DEFAULT_RAG_TEMPLATE.
RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE",
    "rag.template",
    os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
|
|
|
|
# OpenAI-compatible endpoint and key for RAG; each falls back to the
# corresponding global OPENAI_API_* value when its own env var is unset.
RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.environ.get("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.environ.get("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
|
|
|
|
# Plain module-level boolean (not wrapped in PersistentConfig): True only
# when the env var equals "true" in any letter case.
ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.environ.get("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)

# Comma-separated env value split into a list; ["en"] when unset.
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    os.environ.get("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
|
|
|
|
|
|
# Web search toggle: the env-derived value is True only for "true" (any case).
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.environ.get("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)

# Name of the web search engine; "" when the env var is unset.
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.environ.get("RAG_WEB_SEARCH_ENGINE", ""),
)
|
|
|
|
# Optional list of your own trusted domains, used to filter results after a
# web search so that only known, reliable sources are kept.
# The value given here is an empty list and is not read from the environment.
# NOTE(review): the config path begins with "rag.rag." whereas sibling
# settings use "rag.web.search.*" — looks like a typo, but renaming it would
# presumably orphan values already stored under this path; confirm first.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.rag.web.search.domain.filter_list",
    [
        # "wikipedia.org",
        # "wikimedia.org",
        # "wikidata.org",
    ],
)
|
|
|
|
# SearXNG query URL; "" when the env var is unset.
SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.environ.get("SEARXNG_QUERY_URL", ""),
)
|
|
|
|
# Google PSE credentials: API key and engine id, both "" when unset.
GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.environ.get("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.environ.get("GOOGLE_PSE_ENGINE_ID", ""),
)
|
|
|
|
# Brave Search API key; "" when the env var is unset.
BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.environ.get("BRAVE_SEARCH_API_KEY", ""),
)
|
|
|
|
# Serpstack API key; "" when the env var is unset.
SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.environ.get("SERPSTACK_API_KEY", ""),
)

# HTTPS flag for Serpstack: the default "True" yields True; any value other
# than "true" (case-insensitive) yields False.
SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.environ.get("SERPSTACK_HTTPS", "True").lower() == "true",
)
|
|
|
|
# API keys for the Serper, Serply and Tavily search providers.
# Each env value falls back to "" when the variable is unset.
SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.environ.get("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.environ.get("SERPLY_API_KEY", ""),
)

TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.environ.get("TAVILY_API_KEY", ""),
)
|
|
|
|
# SearchApi settings: API key and engine name, both "" when unset.
SEARCHAPI_API_KEY = PersistentConfig(
    "SEARCHAPI_API_KEY",
    "rag.web.search.searchapi_api_key",
    os.environ.get("SEARCHAPI_API_KEY", ""),
)

SEARCHAPI_ENGINE = PersistentConfig(
    "SEARCHAPI_ENGINE",
    "rag.web.search.searchapi_engine",
    os.environ.get("SEARCHAPI_ENGINE", ""),
)
|
|
|
|
# Integer web-search limits, parsed with int() at import time
# (a non-numeric env value raises ValueError).
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.environ.get("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)

RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.environ.get("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
|
|
|
|
|
|
####################################
# Transcribe
####################################
|
|
|
|
# Whisper settings are plain module constants (not wrapped in PersistentConfig).
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "base")
# Model directory; falls back to "<CACHE_DIR>/whisper/models" when unset.
WHISPER_MODEL_DIR = os.environ.get(
    "WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
)
# True only when the env var equals "true" (any letter case).
WHISPER_MODEL_AUTO_UPDATE = (
    os.getenv("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
|
|
|
|
|
|
####################################
# Images
####################################
|
|
|
|
# Image generation backend name; "openai" unless overridden via the environment.
IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.environ.get("IMAGE_GENERATION_ENGINE", "openai"),
)

# Image generation toggle: env value is True only for "true" (any case).
ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.getenv("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
|
|
# AUTOMATIC1111 connection settings: base URL and API auth string,
# both "" when the env var is unset.
AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.environ.get("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.environ.get("AUTOMATIC1111_API_AUTH", ""),
)
|
|
|
|
# CFG scale for AUTOMATIC1111: parsed as a float when the env var is set to a
# non-empty string, otherwise None.
AUTOMATIC1111_CFG_SCALE = PersistentConfig(
    "AUTOMATIC1111_CFG_SCALE",
    "image_generation.automatic1111.cfg_scale",
    (
        float(os.getenv("AUTOMATIC1111_CFG_SCALE"))
        if os.getenv("AUTOMATIC1111_CFG_SCALE")
        else None
    ),
)
|
|
|
|
|
|
# Sampler name for AUTOMATIC1111: the env value when set and non-empty,
# otherwise None.
# The first argument previously read "AUTOMATIC1111_SAMPLERE" (stray "E"),
# which matched neither this constant nor the env var read below.
# NOTE(review): the name string appears to be used only for identification;
# confirm nothing keys on the old misspelling.
AUTOMATIC1111_SAMPLER = PersistentConfig(
    "AUTOMATIC1111_SAMPLER",
    "image_generation.automatic1111.sampler",
    os.environ.get("AUTOMATIC1111_SAMPLER") or None,
)
|
|
|
|
# Scheduler name for AUTOMATIC1111: the env value when set and non-empty,
# otherwise None (an empty string is normalised to None).
AUTOMATIC1111_SCHEDULER = PersistentConfig(
    "AUTOMATIC1111_SCHEDULER",
    "image_generation.automatic1111.scheduler",
    os.getenv("AUTOMATIC1111_SCHEDULER") or None,
)
|
|
|
|
# ComfyUI base URL; "" when the env var is unset.
COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.environ.get("COMFYUI_BASE_URL", ""),
)
|
|
|
|
# Default ComfyUI workflow, stored as a JSON document in a string.
# It is a node graph keyed by node id: KSampler ("3"), checkpoint loader
# ("4"), empty latent image ("5"), two CLIP text encoders ("6" and "7"),
# VAE decode ("8") and save image ("9"). Two-element lists such as ["4", 0]
# appear to reference another node's output as [node id, output index]
# (ComfyUI convention — confirm against the ComfyUI API docs).
# The literal must stay valid JSON: any stray line inside it breaks parsing.
COMFYUI_DEFAULT_WORKFLOW = """
{
  "3": {
    "inputs": {
      "seed": 0,
      "steps": 20,
      "cfg": 8,
      "sampler_name": "euler",
      "scheduler": "normal",
      "denoise": 1,
      "model": [
        "4",
        0
      ],
      "positive": [
        "6",
        0
      ],
      "negative": [
        "7",
        0
      ],
      "latent_image": [
        "5",
        0
      ]
    },
    "class_type": "KSampler",
    "_meta": {
      "title": "KSampler"
    }
  },
  "4": {
    "inputs": {
      "ckpt_name": "model.safetensors"
    },
    "class_type": "CheckpointLoaderSimple",
    "_meta": {
      "title": "Load Checkpoint"
    }
  },
  "5": {
    "inputs": {
      "width": 512,
      "height": 512,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": {
      "title": "Empty Latent Image"
    }
  },
  "6": {
    "inputs": {
      "text": "Prompt",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "7": {
    "inputs": {
      "text": "",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "8": {
    "inputs": {
      "samples": [
        "3",
        0
      ],
      "vae": [
        "4",
        2
      ]
    },
    "class_type": "VAEDecode",
    "_meta": {
      "title": "VAE Decode"
    }
  },
  "9": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": [
        "8",
        0
      ]
    },
    "class_type": "SaveImage",
    "_meta": {
      "title": "Save Image"
    }
  }
}
"""
|
|
|
|
|
|
# ComfyUI workflow JSON string; COMFYUI_DEFAULT_WORKFLOW is used when the
# env var is unset.
COMFYUI_WORKFLOW = PersistentConfig(
    "COMFYUI_WORKFLOW",
    "image_generation.comfyui.workflow",
    os.environ.get("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
)
|
|
|
|
# ComfyUI workflow node list; the value given here is an empty list and is
# not read from the environment.
# The first argument previously duplicated the sibling setting's name
# ("COMFYUI_WORKFLOW"), so the two entries could not be told apart by name.
# NOTE(review): the name string appears to be used only for identification;
# confirm nothing relies on the old duplicated name.
COMFYUI_WORKFLOW_NODES = PersistentConfig(
    "COMFYUI_WORKFLOW_NODES",
    "image_generation.comfyui.nodes",
    [],
)
|
|
|
|
# OpenAI-compatible endpoint and key for image generation; each falls back to
# the corresponding global OPENAI_API_* value when its own env var is unset.
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.environ.get("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.environ.get("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
|
|
|
|
# Image size string; "512x512" when the env var is unset.
IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE",
    "image_generation.size",
    os.environ.get("IMAGE_SIZE", "512x512"),
)
|
|
|
|
# Number of image generation steps, parsed with int() at import time
# (a non-numeric env value raises ValueError).
# The fallback is the string "50", matching the other integer settings in
# this module; the resulting value is still the int 50.
IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS",
    "image_generation.steps",
    int(os.getenv("IMAGE_STEPS", "50")),
)
|
|
|
|
# Image generation model name; "" when the env var is unset.
IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.environ.get("IMAGE_GENERATION_MODEL", ""),
)
|
|
|
|
####################################
# Audio
####################################
|
|
|
|
# OpenAI-compatible endpoint and key for speech-to-text; each falls back to
# the corresponding global OPENAI_API_* value when its own env var is unset.
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.environ.get("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.environ.get("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
|
|
|
|
# Speech-to-text engine name; "" when the env var is unset.
AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.environ.get("AUDIO_STT_ENGINE", ""),
)

# Speech-to-text model name; "whisper-1" when the env var is unset.
AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.environ.get("AUDIO_STT_MODEL", "whisper-1"),
)
|
|
|
|
# OpenAI-compatible endpoint and key for text-to-speech; each falls back to
# the corresponding global OPENAI_API_* value when its own env var is unset.
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.environ.get("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.environ.get("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
|
|
|
|
# Text-to-speech API key (separate from the OpenAI-specific key);
# "" when the env var is unset.
AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.environ.get("AUDIO_TTS_API_KEY", ""),
)

# Text-to-speech engine name; "" when the env var is unset.
AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.environ.get("AUDIO_TTS_ENGINE", ""),
)
|
|
|
|
|
|
# Text-to-speech model, voice and split mode, each with a string fallback.
AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.environ.get("AUDIO_TTS_MODEL", "tts-1"),  # OpenAI default model
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.environ.get("AUDIO_TTS_VOICE", "alloy"),  # OpenAI default voice
)

# Falls back to "punctuation" when the env var is unset.
AUDIO_TTS_SPLIT_ON = PersistentConfig(
    "AUDIO_TTS_SPLIT_ON",
    "audio.tts.split_on",
    os.environ.get("AUDIO_TTS_SPLIT_ON", "punctuation"),
)
|
|
|
|
# Azure speech region for text-to-speech; "eastus" when the env var is unset.
AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_REGION",
    "audio.tts.azure.speech_region",
    os.environ.get("AUDIO_TTS_AZURE_SPEECH_REGION", "eastus"),
)

# Azure speech output format; "audio-24khz-160kbitrate-mono-mp3" when unset.
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
    "audio.tts.azure.speech_output_format",
    os.environ.get(
        "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
    ),
)
|