Merge branch 'dev' into feat/backend-web-search

This commit is contained in:
Jun Siang Cheah
2024-05-14 14:03:23 +08:00
48 changed files with 2650 additions and 1030 deletions

View File

@@ -5,6 +5,7 @@ import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from pathlib import Path
import json
@@ -17,7 +18,6 @@ import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
@@ -71,7 +71,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
@@ -161,16 +160,6 @@ CHANGELOG = changelog_json
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
####################################
# DATA/FRONTEND BUILD DIR
####################################
@@ -184,6 +173,108 @@ try:
except:
CONFIG_DATA = {}
####################################
# Config helpers
####################################
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
json.dump(CONFIG_DATA, f, indent="\t")
except Exception as e:
log.exception(e)
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
T = TypeVar("T")
class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
self.value = self.config_value
else:
self.value = env_value
def __str__(self):
return str(self.value)
@property
def __dict__(self):
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def __getattribute__(self, item):
if item == "__dict__":
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return super().__getattribute__(item)
def save(self):
# Don't save if the value is the same as the env value and the config value
if self.env_value == self.value:
if self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to config.json")
path_parts = self.config_path.split(".")
config = CONFIG_DATA
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
save_config()
self.config_value = self.value
class AppConfig:
_state: dict[str, PersistentConfig]
def __init__(self):
super().__setattr__("_state", {})
def __setattr__(self, key, value):
if isinstance(value, PersistentConfig):
self._state[key] = value
else:
self._state[key].value = value
self._state[key].save()
def __getattr__(self, key):
return self._state[key].value
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
####################################
# Static DIR
####################################
@@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
####################################
# OPENAI_API
@@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
"OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
@@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)
OPENAI_API_KEY = ""
try:
OPENAI_API_KEY = OPENAI_API_KEYS[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
]
except:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI
####################################
ENABLE_SIGNUP = (
False
if WEBUI_AUTH == False
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
ENABLE_SIGNUP = PersistentConfig(
"ENABLE_SIGNUP",
"ui.enable_signup",
(
False
if not WEBUI_AUTH
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
),
)
DEFAULT_MODELS = PersistentConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = (
CONFIG_DATA["ui"]["prompt_suggestions"]
if "ui" in CONFIG_DATA
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
"DEFAULT_PROMPT_SUGGESTIONS",
"ui.prompt_suggestions",
[
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
]
],
)
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
DEFAULT_USER_ROLE = PersistentConfig(
"DEFAULT_USER_ROLE",
"ui.default_user_role",
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
USER_PERMISSIONS = PersistentConfig(
"USER_PERMISSIONS",
"ui.user_permissions",
{"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
ENABLE_MODEL_FILTER = PersistentConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
MODEL_FILTER_LIST = PersistentConfig(
"MODEL_FILTER_LIST",
"model_filter.list",
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
)
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
WEBHOOK_URL = PersistentConfig(
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
@@ -458,26 +575,45 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold",
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
"ENABLE_RAG_HYBRID_SEARCH",
"rag.enable_hybrid_search",
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
"rag.enable_web_loader_ssl_verification",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_ENGINE = PersistentConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)
PDF_EXTRACT_IMAGES = PersistentConfig(
"PDF_EXTRACT_IMAGES",
"rag.pdf_extract_images",
os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)
RAG_EMBEDDING_MODEL = PersistentConfig(
"RAG_EMBEDDING_MODEL",
"rag.embedding_model",
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
@@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
RAG_RERANKING_MODEL = PersistentConfig(
"RAG_RERANKING_MODEL",
"rag.reranking_model",
os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
@@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
else:
DEVICE_TYPE = "cpu"
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
CHUNK_SIZE = PersistentConfig(
"CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
)
CHUNK_OVERLAP = PersistentConfig(
"CHUNK_OVERLAP",
"rag.chunk_overlap",
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
@@ -545,16 +690,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
"rag.template",
os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
RAG_OPENAI_API_BASE_URL = PersistentConfig(
"RAG_OPENAI_API_BASE_URL",
"rag.openai_api_base_url",
os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
"RAG_OPENAI_API_KEY",
"rag.openai_api_key",
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
"YOUTUBE_LOADER_LANGUAGE",
"rag.youtube_loader_language",
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
@@ -590,34 +751,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images
####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
IMAGE_GENERATION_ENGINE = PersistentConfig(
"IMAGE_GENERATION_ENGINE",
"image_generation.engine",
os.getenv("IMAGE_GENERATION_ENGINE", ""),
)
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
ENABLE_IMAGE_GENERATION = PersistentConfig(
"ENABLE_IMAGE_GENERATION",
"image_generation.enable",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
AUTOMATIC1111_BASE_URL = PersistentConfig(
"AUTOMATIC1111_BASE_URL",
"image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""),
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL",
"image_generation.comfyui.base_url",
os.getenv("COMFYUI_BASE_URL", ""),
)
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
"IMAGES_OPENAI_API_BASE_URL",
"image_generation.openai.api_base_url",
os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
"IMAGES_OPENAI_API_KEY",
"image_generation.openai.api_key",
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
IMAGE_STEPS = PersistentConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)
IMAGE_GENERATION_MODEL = PersistentConfig(
"IMAGE_GENERATION_MODEL",
"image_generation.model",
os.getenv("IMAGE_GENERATION_MODEL", ""),
)
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
####################################
# LiteLLM