diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py
index 87732d7bc..0f65a551e 100644
--- a/backend/apps/audio/main.py
+++ b/backend/apps/audio/main.py
@@ -45,6 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE,
+ AppConfig,
)
log = logging.getLogger(__name__)
@@ -59,11 +60,11 @@ app.add_middleware(
allow_headers=["*"],
)
-
-app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
-app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
-app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
+app.state.config = AppConfig()
+app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
+app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
+app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
- "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
- "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
- "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
- "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+ "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
+ "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
}
@@ -97,17 +98,17 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
- app.state.OPENAI_API_BASE_URL = form_data.url
- app.state.OPENAI_API_KEY = form_data.key
- app.state.OPENAI_API_MODEL = form_data.model
- app.state.OPENAI_API_VOICE = form_data.speaker
+ app.state.config.OPENAI_API_BASE_URL = form_data.url
+ app.state.config.OPENAI_API_KEY = form_data.key
+ app.state.config.OPENAI_API_MODEL = form_data.model
+ app.state.config.OPENAI_API_VOICE = form_data.speaker
return {
"status": True,
- "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
- "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
- "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
- "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+ "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
+ "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
}
@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
- headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
- url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
+ url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index f45cf0d12..1c309439d 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -42,6 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
+ AppConfig,
)
@@ -60,26 +61,31 @@ app.add_middleware(
allow_headers=["*"],
)
-app.state.ENGINE = IMAGE_GENERATION_ENGINE
-app.state.ENABLED = ENABLE_IMAGE_GENERATION
+app.state.config = AppConfig()
-app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
+app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
+app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
-app.state.MODEL = IMAGE_GENERATION_MODEL
+app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
+
+app.state.config.MODEL = IMAGE_GENERATION_MODEL
-app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
-app.state.IMAGE_SIZE = IMAGE_SIZE
-app.state.IMAGE_STEPS = IMAGE_STEPS
+app.state.config.IMAGE_SIZE = IMAGE_SIZE
+app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
- return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
+ return {
+ "engine": app.state.config.ENGINE,
+ "enabled": app.state.config.ENABLED,
+ }
class ConfigUpdateForm(BaseModel):
@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
- app.state.ENGINE = form_data.engine
- app.state.ENABLED = form_data.enabled
- return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
+ app.state.config.ENGINE = form_data.engine
+ app.state.config.ENABLED = form_data.enabled
+ return {
+ "engine": app.state.config.ENGINE,
+ "enabled": app.state.config.ENABLED,
+ }
class EngineUrlUpdateForm(BaseModel):
@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
return {
- "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
- "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
}
@@ -113,29 +122,29 @@ async def update_engine_url(
):
if form_data.AUTOMATIC1111_BASE_URL == None:
- app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+ app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try:
r = requests.head(url)
- app.state.AUTOMATIC1111_BASE_URL = url
+ app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
- app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
- app.state.COMFYUI_BASE_URL = url
+ app.state.config.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
- "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
- "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True,
}
@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
- "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
- "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
@@ -160,13 +169,13 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
- app.state.OPENAI_API_BASE_URL = form_data.url
- app.state.OPENAI_API_KEY = form_data.key
+ app.state.config.OPENAI_API_BASE_URL = form_data.url
+ app.state.config.OPENAI_API_KEY = form_data.key
return {
"status": True,
- "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
- "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
- return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
+ return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update")
@@ -185,9 +194,9 @@ async def update_image_size(
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
- app.state.IMAGE_SIZE = form_data.size
+ app.state.config.IMAGE_SIZE = form_data.size
return {
- "IMAGE_SIZE": app.state.IMAGE_SIZE,
+ "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True,
}
else:
@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
- return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
+ return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update")
@@ -211,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
- app.state.IMAGE_STEPS = form_data.steps
+ app.state.config.IMAGE_STEPS = form_data.steps
return {
- "IMAGE_STEPS": app.state.IMAGE_STEPS,
+ "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True,
}
else:
@@ -226,14 +235,14 @@ async def update_image_size(
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
- if app.state.ENGINE == "openai":
+ if app.state.config.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
- elif app.state.ENGINE == "comfyui":
+ elif app.state.config.ENGINE == "comfyui":
- r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
+ r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else:
r = requests.get(
- url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
)
)
except Exception as e:
- app.state.ENABLED = False
+ app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
- if app.state.ENGINE == "openai":
- return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
- elif app.state.ENGINE == "comfyui":
- return {"model": app.state.MODEL if app.state.MODEL else ""}
+ if app.state.config.ENGINE == "openai":
+ return {
+ "model": (
+ app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
+ )
+ }
+ elif app.state.config.ENGINE == "comfyui":
+ return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else:
- r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+ r = requests.get(
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+ )
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
- app.state.ENABLED = False
+ app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
- if app.state.ENGINE == "openai":
- app.state.MODEL = model
- return app.state.MODEL
- if app.state.ENGINE == "comfyui":
- app.state.MODEL = model
- return app.state.MODEL
+ if app.state.config.ENGINE in ["openai", "comfyui"]:
+ app.state.config.MODEL = model
+ return app.state.config.MODEL
else:
- r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+ r = requests.get(
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+ )
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
- url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+ json=options,
)
return options
@@ -382,26 +397,32 @@ def generate_image(
user=Depends(get_current_user),
):
- width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+ width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
r = None
try:
- if app.state.ENGINE == "openai":
+ if app.state.config.ENGINE == "openai":
headers = {}
- headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = {
- "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
+ "model": (
+ app.state.config.MODEL
+ if app.state.config.MODEL != ""
+ else "dall-e-2"
+ ),
"prompt": form_data.prompt,
"n": form_data.n,
- "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
+ "size": (
+ form_data.size if form_data.size else app.state.config.IMAGE_SIZE
+ ),
"response_format": "b64_json",
}
r = requests.post(
- url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
+ url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
@@ -421,7 +442,7 @@ def generate_image(
return images
- elif app.state.ENGINE == "comfyui":
+ elif app.state.config.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
@@ -430,19 +451,19 @@ def generate_image(
"n": form_data.n,
}
- if app.state.IMAGE_STEPS != None:
- data["steps"] = app.state.IMAGE_STEPS
+ if app.state.config.IMAGE_STEPS is not None:
+ data["steps"] = app.state.config.IMAGE_STEPS
- if form_data.negative_prompt != None:
+ if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
- app.state.MODEL,
+ app.state.config.MODEL,
data,
user.id,
- app.state.COMFYUI_BASE_URL,
+ app.state.config.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
@@ -469,14 +490,14 @@ def generate_image(
"height": height,
}
- if app.state.IMAGE_STEPS != None:
- data["steps"] = app.state.IMAGE_STEPS
+ if app.state.config.IMAGE_STEPS is not None:
+ data["steps"] = app.state.config.IMAGE_STEPS
- if form_data.negative_prompt != None:
+ if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
- url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 042d0336d..cb80eeed2 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
UPLOAD_DIR,
+ AppConfig,
)
from utils.misc import calculate_sha256
@@ -61,11 +62,12 @@ app.add_middleware(
allow_headers=["*"],
)
+app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
+app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
@@ -96,7 +98,7 @@ async def get_status():
@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
- return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
+ return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel):
@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
- app.state.OLLAMA_BASE_URLS = form_data.urls
+ app.state.config.OLLAMA_BASE_URLS = form_data.urls
- log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
- return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
+ log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
+ return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
@@ -153,7 +155,7 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
- tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
+ tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks)
models = {
@@ -186,7 +188,7 @@ async def get_ollama_tags(
return models
return models
else:
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
@@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx == None:
# returns lowest version
- tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
+ tasks = [
+ fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
+ ]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
@@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
else:
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
@@ -267,7 +271,7 @@ class ModelNameForm(BaseModel):
async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
@@ -355,7 +359,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
r = None
@@ -417,7 +421,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
log.debug(f"form_data: {form_data}")
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
@@ -490,7 +494,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
@@ -537,7 +541,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
@@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
)
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
@@ -634,7 +638,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
@@ -684,7 +688,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
@@ -753,7 +757,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
@@ -856,7 +860,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
@@ -965,7 +969,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
@@ -1064,7 +1068,7 @@ async def get_openai_models(
}
else:
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
@@ -1198,7 +1202,7 @@ async def download_model(
if url_idx == None:
url_idx = 0
- url = app.state.OLLAMA_BASE_URLS[url_idx]
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
@@ -1217,7 +1221,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None:
url_idx = 0
- ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
+ ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"
@@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
-# url = app.state.OLLAMA_BASE_URLS[url_idx]
+# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
@@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
- url = app.state.OLLAMA_BASE_URLS[0]
+ url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index b38c2bc2d..fb16a579b 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -26,6 +26,7 @@ from config import (
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
+ AppConfig,
)
from typing import List, Optional
@@ -45,11 +46,13 @@ app.add_middleware(
allow_headers=["*"],
)
+app.state.config = AppConfig()
+
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
-app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
+app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
+app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
- return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
+ return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
- app.state.OPENAI_API_BASE_URLS = form_data.urls
- return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
+ app.state.config.OPENAI_API_BASE_URLS = form_data.urls
+ return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
- return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
+ return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
- app.state.OPENAI_API_KEYS = form_data.keys
- return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
+ app.state.config.OPENAI_API_KEYS = form_data.keys
+ return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
- idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+ idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body()
name = hashlib.sha256(body).hexdigest()
@@ -114,7 +117,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
- headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.OPENAI_API_BASE_URLS[idx]:
headers['HTTP-Referer'] = "https://openwebui.com/"
@@ -122,7 +125,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
r = None
try:
r = requests.post(
- url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
+ url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
@@ -182,7 +185,8 @@ def merge_models_lists(model_lists):
[
{**model, "urlIdx": idx}
for model in models
- if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
+ if "api.openai.com"
+ not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
@@ -193,12 +197,15 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
- if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
+ if (
+ len(app.state.config.OPENAI_API_KEYS) == 1
+ and app.state.config.OPENAI_API_KEYS[0] == ""
+ ):
models = {"data": []}
else:
tasks = [
- fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
- for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
+ fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+ for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
@@ -241,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models
return models
else:
- url = app.state.OPENAI_API_BASE_URLS[url_idx]
+ url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None
@@ -305,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
- url = app.state.OPENAI_API_BASE_URLS[idx]
- key = app.state.OPENAI_API_KEYS[idx]
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
+ key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 2e2a8e209..d2c3964ae 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -93,6 +93,7 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
+ AppConfig,
)
from constants import ERROR_MESSAGES
@@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
-app.state.TOP_K = RAG_TOP_K
-app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+app.state.config = AppConfig()
-app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
-app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+app.state.config.TOP_K = RAG_TOP_K
+app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+
+app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
-app.state.CHUNK_SIZE = CHUNK_SIZE
-app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.config.CHUNK_SIZE = CHUNK_SIZE
+app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
-app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
-app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
-app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
+app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
+app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
-app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
+app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
-app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
+app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
-app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
+app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
@@ -133,7 +136,7 @@ def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
- if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+ if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@@ -158,22 +161,22 @@ def update_reranking_model(
update_embedding_model(
- app.state.RAG_EMBEDDING_MODEL,
+ app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
- app.state.RAG_RERANKING_MODEL,
+ app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
- app.state.RAG_EMBEDDING_ENGINE,
- app.state.RAG_EMBEDDING_MODEL,
+ app.state.config.RAG_EMBEDDING_ENGINE,
+ app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
- app.state.OPENAI_API_KEY,
- app.state.OPENAI_API_BASE_URL,
+ app.state.config.OPENAI_API_KEY,
+ app.state.config.OPENAI_API_BASE_URL,
)
origins = ["*"]
@@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm):
async def get_status():
return {
"status": True,
- "chunk_size": app.state.CHUNK_SIZE,
- "chunk_overlap": app.state.CHUNK_OVERLAP,
- "template": app.state.RAG_TEMPLATE,
- "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
- "embedding_model": app.state.RAG_EMBEDDING_MODEL,
- "reranking_model": app.state.RAG_RERANKING_MODEL,
+ "chunk_size": app.state.config.CHUNK_SIZE,
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
+ "template": app.state.config.RAG_TEMPLATE,
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
@@ -213,18 +216,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)):
return {
"status": True,
- "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
- "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
- "url": app.state.OPENAI_API_BASE_URL,
- "key": app.state.OPENAI_API_KEY,
+ "url": app.state.config.OPENAI_API_BASE_URL,
+ "key": app.state.config.OPENAI_API_KEY,
},
}
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
- return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
+ return {
+ "status": True,
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
+ }
class OpenAIConfigForm(BaseModel):
@@ -243,34 +249,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
- f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
+ f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
- app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
- app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+ app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
+ app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
- if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
+ if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
- app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
- app.state.OPENAI_API_KEY = form_data.openai_config.key
+ app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
+ app.state.config.OPENAI_API_KEY = form_data.openai_config.key
- update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
+ update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
app.state.EMBEDDING_FUNCTION = get_embedding_function(
- app.state.RAG_EMBEDDING_ENGINE,
- app.state.RAG_EMBEDDING_MODEL,
+ app.state.config.RAG_EMBEDDING_ENGINE,
+ app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
- app.state.OPENAI_API_KEY,
- app.state.OPENAI_API_BASE_URL,
+ app.state.config.OPENAI_API_KEY,
+ app.state.config.OPENAI_API_BASE_URL,
)
return {
"status": True,
- "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
- "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
- "url": app.state.OPENAI_API_BASE_URL,
- "key": app.state.OPENAI_API_KEY,
+ "url": app.state.config.OPENAI_API_BASE_URL,
+ "key": app.state.config.OPENAI_API_KEY,
},
}
except Exception as e:
@@ -290,16 +296,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
- f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
+ f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
- app.state.RAG_RERANKING_MODEL = form_data.reranking_model
+ app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
- update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
+ update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
return {
"status": True,
- "reranking_model": app.state.RAG_RERANKING_MODEL,
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
@@ -313,14 +319,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
- "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
- "chunk_size": app.state.CHUNK_SIZE,
- "chunk_overlap": app.state.CHUNK_OVERLAP,
+ "chunk_size": app.state.config.CHUNK_SIZE,
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
- "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
- "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
- app.state.PDF_EXTRACT_IMAGES = (
+ app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
- if form_data.pdf_extract_images != None
- else app.state.PDF_EXTRACT_IMAGES
+ if form_data.pdf_extract_images is not None
+ else app.state.config.PDF_EXTRACT_IMAGES
)
- app.state.CHUNK_SIZE = (
- form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
+ app.state.config.CHUNK_SIZE = (
+ form_data.chunk.chunk_size
+ if form_data.chunk is not None
+ else app.state.config.CHUNK_SIZE
)
- app.state.CHUNK_OVERLAP = (
+ app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
- if form_data.chunk != None
- else app.state.CHUNK_OVERLAP
+ if form_data.chunk is not None
+ else app.state.config.CHUNK_OVERLAP
)
- app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
- else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+ else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
- app.state.YOUTUBE_LOADER_LANGUAGE = (
+ app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
- if form_data.youtube != None
- else app.state.YOUTUBE_LOADER_LANGUAGE
+ if form_data.youtube is not None
+ else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
- if form_data.youtube != None
+ if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
return {
"status": True,
- "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
- "chunk_size": app.state.CHUNK_SIZE,
- "chunk_overlap": app.state.CHUNK_OVERLAP,
+ "chunk_size": app.state.config.CHUNK_SIZE,
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
- "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
- "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
- "template": app.state.RAG_TEMPLATE,
+ "template": app.state.config.RAG_TEMPLATE,
}
@@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
- "template": app.state.RAG_TEMPLATE,
- "k": app.state.TOP_K,
- "r": app.state.RELEVANCE_THRESHOLD,
- "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
+ "template": app.state.config.RAG_TEMPLATE,
+ "k": app.state.config.TOP_K,
+ "r": app.state.config.RELEVANCE_THRESHOLD,
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
- app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
- app.state.TOP_K = form_data.k if form_data.k else 4
- app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
- app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
+ app.state.config.RAG_TEMPLATE = (
+ form_data.template if form_data.template else RAG_TEMPLATE,
+ )
+ app.state.config.TOP_K = form_data.k if form_data.k else 4
+ app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
+ form_data.hybrid if form_data.hybrid else False,
+ )
return {
"status": True,
- "template": app.state.RAG_TEMPLATE,
- "k": app.state.TOP_K,
- "r": app.state.RELEVANCE_THRESHOLD,
- "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
+ "template": app.state.config.RAG_TEMPLATE,
+ "k": app.state.config.TOP_K,
+ "r": app.state.config.RELEVANCE_THRESHOLD,
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -451,21 +463,23 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
- if app.state.ENABLE_RAG_HYBRID_SEARCH:
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.TOP_K,
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
- r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+ r=(
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
+ ),
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.TOP_K,
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
log.exception(e)
@@ -489,21 +503,23 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
- if app.state.ENABLE_RAG_HYBRID_SEARCH:
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.TOP_K,
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
- r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+ r=(
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
+ ),
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
- k=form_data.k if form_data.k else app.state.TOP_K,
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
@@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
- language=app.state.YOUTUBE_LOADER_LANGUAGE,
+ language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
)
data = loader.load()
@@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
- form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+ form_data.url,
+ verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
)
data = loader.load()
@@ -604,8 +621,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
- chunk_size=app.state.CHUNK_SIZE,
- chunk_overlap=app.state.CHUNK_OVERLAP,
+ chunk_size=app.state.config.CHUNK_SIZE,
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
@@ -622,8 +639,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
- chunk_size=app.state.CHUNK_SIZE,
- chunk_overlap=app.state.CHUNK_OVERLAP,
+ chunk_size=app.state.config.CHUNK_SIZE,
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
@@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
- app.state.RAG_EMBEDDING_ENGINE,
- app.state.RAG_EMBEDDING_MODEL,
+ app.state.config.RAG_EMBEDDING_ENGINE,
+ app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
- app.state.OPENAI_API_KEY,
- app.state.OPENAI_API_BASE_URL,
+ app.state.config.OPENAI_API_KEY,
+ app.state.config.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if file_ext == "pdf":
- loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
+ loader = PyPDFLoader(
+ file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
+ )
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py
index 66cdfb3d4..755e3911b 100644
--- a/backend/apps/web/main.py
+++ b/backend/apps/web/main.py
@@ -21,20 +21,24 @@ from config import (
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
+ JWT_EXPIRES_IN,
+ AppConfig,
)
app = FastAPI()
origins = ["*"]
-app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
-app.state.JWT_EXPIRES_IN = "-1"
+app.state.config = AppConfig()
-app.state.DEFAULT_MODELS = DEFAULT_MODELS
-app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
-app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
-app.state.USER_PERMISSIONS = USER_PERMISSIONS
-app.state.WEBHOOK_URL = WEBHOOK_URL
+app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
+app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
+
+app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
+app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
+app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
+app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
+app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware(
@@ -61,6 +65,6 @@ async def get_status():
return {
"status": True,
"auth": WEBUI_AUTH,
- "default_models": app.state.DEFAULT_MODELS,
- "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
+ "default_models": app.state.config.DEFAULT_MODELS,
+ "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py
index 9fa962dda..998e74659 100644
--- a/backend/apps/web/routers/auths.py
+++ b/backend/apps/web/routers/auths.py
@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user:
token = create_token(
data={"id": user.id},
- expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
+ expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
return {
@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm):
- if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH:
+ if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = (
"admin"
if Users.get_num_users() == 0
- else request.app.state.DEFAULT_USER_ROLE
+ else request.app.state.config.DEFAULT_USER_ROLE
)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if user:
token = create_token(
data={"id": user.id},
- expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
+ expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# response.set_cookie(key='token', value=token, httponly=True)
- if request.app.state.WEBHOOK_URL:
+ if request.app.state.config.WEBHOOK_URL:
post_webhook(
- request.app.state.WEBHOOK_URL,
+ request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
- return request.app.state.ENABLE_SIGNUP
+ return request.app.state.config.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
- request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
- return request.app.state.ENABLE_SIGNUP
+ request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
+ return request.app.state.config.ENABLE_SIGNUP
############################
@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
- return request.app.state.DEFAULT_USER_ROLE
+ return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
@@ -304,8 +304,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
- request.app.state.DEFAULT_USER_ROLE = form_data.role
- return request.app.state.DEFAULT_USER_ROLE
+ request.app.state.config.DEFAULT_USER_ROLE = form_data.role
+ return request.app.state.config.DEFAULT_USER_ROLE
############################
@@ -315,7 +315,7 @@ async def update_default_user_role(
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
- return request.app.state.JWT_EXPIRES_IN
+ return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
@@ -332,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
- request.app.state.JWT_EXPIRES_IN = form_data.duration
- return request.app.state.JWT_EXPIRES_IN
+ request.app.state.config.JWT_EXPIRES_IN = form_data.duration
+ return request.app.state.config.JWT_EXPIRES_IN
else:
- return request.app.state.JWT_EXPIRES_IN
+ return request.app.state.config.JWT_EXPIRES_IN
############################
diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py
index 0bad55a6a..143ed5e0a 100644
--- a/backend/apps/web/routers/configs.py
+++ b/backend/apps/web/routers/configs.py
@@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
):
- request.app.state.DEFAULT_MODELS = form_data.models
- return request.app.state.DEFAULT_MODELS
+ request.app.state.config.DEFAULT_MODELS = form_data.models
+ return request.app.state.config.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
@@ -55,5 +55,5 @@ async def set_global_default_suggestions(
user=Depends(get_admin_user),
):
data = form_data.model_dump()
- request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
- return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
+ request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
+ return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py
index 59f6c21b7..d87854e89 100644
--- a/backend/apps/web/routers/users.py
+++ b/backend/apps/web/routers/users.py
@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
- return request.app.state.USER_PERMISSIONS
+ return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
):
- request.app.state.USER_PERMISSIONS = form_data
- return request.app.state.USER_PERMISSIONS
+ request.app.state.config.USER_PERMISSIONS = form_data
+ return request.app.state.config.USER_PERMISSIONS
############################
diff --git a/backend/config.py b/backend/config.py
index 5c6247a9f..112edba90 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -5,6 +5,7 @@ import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
+from typing import TypeVar, Generic, Union
from pathlib import Path
import json
@@ -17,7 +18,6 @@ import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
-
####################################
# Load .env file
####################################
@@ -71,7 +71,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
-
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
@@ -161,16 +160,6 @@ CHANGELOG = changelog_json
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
-####################################
-# WEBUI_AUTH (Required for security)
-####################################
-
-WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
-WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
- "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
-)
-
-
####################################
# DATA/FRONTEND BUILD DIR
####################################
@@ -184,6 +173,108 @@ try:
except:
CONFIG_DATA = {}
+
+####################################
+# Config helpers
+####################################
+
+
+def save_config():
+ try:
+ with open(f"{DATA_DIR}/config.json", "w") as f:
+ json.dump(CONFIG_DATA, f, indent="\t")
+ except Exception as e:
+ log.exception(e)
+
+
+def get_config_value(config_path: str):
+ path_parts = config_path.split(".")
+ cur_config = CONFIG_DATA
+ for key in path_parts:
+ if key in cur_config:
+ cur_config = cur_config[key]
+ else:
+ return None
+ return cur_config
+
+
+T = TypeVar("T")
+
+
+class PersistentConfig(Generic[T]):
+ def __init__(self, env_name: str, config_path: str, env_value: T):
+ self.env_name = env_name
+ self.config_path = config_path
+ self.env_value = env_value
+ self.config_value = get_config_value(config_path)
+ if self.config_value is not None:
+ log.info(f"'{env_name}' loaded from config.json")
+ self.value = self.config_value
+ else:
+ self.value = env_value
+
+ def __str__(self):
+ return str(self.value)
+
+ @property
+ def __dict__(self):
+ raise TypeError(
+ "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
+ )
+
+ def __getattribute__(self, item):
+ if item == "__dict__":
+ raise TypeError(
+ "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
+ )
+ return super().__getattribute__(item)
+
+ def save(self):
+ # Don't save if the value is the same as the env value and the config value
+ if self.env_value == self.value:
+ if self.config_value == self.value:
+ return
+ log.info(f"Saving '{self.env_name}' to config.json")
+ path_parts = self.config_path.split(".")
+ config = CONFIG_DATA
+ for key in path_parts[:-1]:
+ if key not in config:
+ config[key] = {}
+ config = config[key]
+ config[path_parts[-1]] = self.value
+ save_config()
+ self.config_value = self.value
+
+
+class AppConfig:
+ _state: dict[str, PersistentConfig]
+
+ def __init__(self):
+ super().__setattr__("_state", {})
+
+ def __setattr__(self, key, value):
+ if isinstance(value, PersistentConfig):
+ self._state[key] = value
+ else:
+ self._state[key].value = value
+ self._state[key].save()
+
+ def __getattr__(self, key):
+ return self._state[key].value
+
+
+####################################
+# WEBUI_AUTH (Required for security)
+####################################
+
+WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
+WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
+ "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
+)
+JWT_EXPIRES_IN = PersistentConfig(
+ "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
+)
+
####################################
# Static DIR
####################################
@@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
-
+OLLAMA_BASE_URLS = PersistentConfig(
+ "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
+)
####################################
# OPENAI_API
@@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
-
+OPENAI_API_KEYS = PersistentConfig(
+ "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
+)
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
@@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";")
]
+OPENAI_API_BASE_URLS = PersistentConfig(
+ "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
+)
OPENAI_API_KEY = ""
try:
- OPENAI_API_KEY = OPENAI_API_KEYS[
- OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+ OPENAI_API_KEY = OPENAI_API_KEYS.value[
+ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
]
except:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
-
####################################
# WEBUI
####################################
-ENABLE_SIGNUP = (
- False
- if WEBUI_AUTH == False
- else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
+ENABLE_SIGNUP = PersistentConfig(
+ "ENABLE_SIGNUP",
+ "ui.enable_signup",
+ (
+ False
+ if not WEBUI_AUTH
+ else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
+ ),
+)
+DEFAULT_MODELS = PersistentConfig(
+ "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)
-DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
-
-DEFAULT_PROMPT_SUGGESTIONS = (
- CONFIG_DATA["ui"]["prompt_suggestions"]
- if "ui" in CONFIG_DATA
- and "prompt_suggestions" in CONFIG_DATA["ui"]
- and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
- else [
+DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
+ "DEFAULT_PROMPT_SUGGESTIONS",
+ "ui.prompt_suggestions",
+ [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
- ]
+ ],
)
-
-DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
+DEFAULT_USER_ROLE = PersistentConfig(
+ "DEFAULT_USER_ROLE",
+ "ui.default_user_role",
+ os.getenv("DEFAULT_USER_ROLE", "pending"),
+)
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)
-USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
+USER_PERMISSIONS = PersistentConfig(
+ "USER_PERMISSIONS",
+ "ui.user_permissions",
+ {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
+)
-ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
+ENABLE_MODEL_FILTER = PersistentConfig(
+ "ENABLE_MODEL_FILTER",
+ "model_filter.enable",
+ os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
+)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
-MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
+MODEL_FILTER_LIST = PersistentConfig(
+ "MODEL_FILTER_LIST",
+ "model_filter.list",
+ [model.strip() for model in MODEL_FILTER_LIST.split(";")],
+)
-WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
+WEBHOOK_URL = PersistentConfig(
+ "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
+)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
@@ -458,26 +575,45 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
-RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
-RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
-
-ENABLE_RAG_HYBRID_SEARCH = (
- os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
+RAG_TOP_K = PersistentConfig(
+ "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
+)
+RAG_RELEVANCE_THRESHOLD = PersistentConfig(
+ "RAG_RELEVANCE_THRESHOLD",
+ "rag.relevance_threshold",
+ float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)
-
-ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
- os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
+ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
+ "ENABLE_RAG_HYBRID_SEARCH",
+ "rag.enable_hybrid_search",
+ os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
-RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
-
-PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
-
-RAG_EMBEDDING_MODEL = os.environ.get(
- "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
+ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
+ "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
+ "rag.enable_web_loader_ssl_verification",
+ os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
-log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
+
+RAG_EMBEDDING_ENGINE = PersistentConfig(
+ "RAG_EMBEDDING_ENGINE",
+ "rag.embedding_engine",
+ os.environ.get("RAG_EMBEDDING_ENGINE", ""),
+)
+
+PDF_EXTRACT_IMAGES = PersistentConfig(
+ "PDF_EXTRACT_IMAGES",
+ "rag.pdf_extract_images",
+ os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
+)
+
+RAG_EMBEDDING_MODEL = PersistentConfig(
+ "RAG_EMBEDDING_MODEL",
+ "rag.embedding_model",
+ os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
+)
+log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
@@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
-RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
-if not RAG_RERANKING_MODEL == "":
- log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
+RAG_RERANKING_MODEL = PersistentConfig(
+ "RAG_RERANKING_MODEL",
+ "rag.reranking_model",
+ os.environ.get("RAG_RERANKING_MODEL", ""),
+)
+if RAG_RERANKING_MODEL.value != "":
+ log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
@@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
else:
DEVICE_TYPE = "cpu"
-
-CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
-CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
+CHUNK_SIZE = PersistentConfig(
+ "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
+)
+CHUNK_OVERLAP = PersistentConfig(
+ "CHUNK_OVERLAP",
+ "rag.chunk_overlap",
+ int(os.environ.get("CHUNK_OVERLAP", "100")),
+)
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags.
@@ -545,16 +690,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
-RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
+RAG_TEMPLATE = PersistentConfig(
+ "RAG_TEMPLATE",
+ "rag.template",
+ os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
+)
-RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
-RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
+RAG_OPENAI_API_BASE_URL = PersistentConfig(
+ "RAG_OPENAI_API_BASE_URL",
+ "rag.openai_api_base_url",
+ os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+RAG_OPENAI_API_KEY = PersistentConfig(
+ "RAG_OPENAI_API_KEY",
+ "rag.openai_api_key",
+ os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
+)
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
-YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
+YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
+ "YOUTUBE_LOADER_LANGUAGE",
+ "rag.youtube_loader_language",
+ os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
+)
####################################
# Transcribe
@@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images
####################################
-IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
-
-ENABLE_IMAGE_GENERATION = (
- os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
+IMAGE_GENERATION_ENGINE = PersistentConfig(
+ "IMAGE_GENERATION_ENGINE",
+ "image_generation.engine",
+ os.getenv("IMAGE_GENERATION_ENGINE", ""),
)
-AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
-COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
-
-IMAGES_OPENAI_API_BASE_URL = os.getenv(
- "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
+ENABLE_IMAGE_GENERATION = PersistentConfig(
+ "ENABLE_IMAGE_GENERATION",
+ "image_generation.enable",
+ os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+)
+AUTOMATIC1111_BASE_URL = PersistentConfig(
+ "AUTOMATIC1111_BASE_URL",
+ "image_generation.automatic1111.base_url",
+ os.getenv("AUTOMATIC1111_BASE_URL", ""),
)
-IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
-IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
+COMFYUI_BASE_URL = PersistentConfig(
+ "COMFYUI_BASE_URL",
+ "image_generation.comfyui.base_url",
+ os.getenv("COMFYUI_BASE_URL", ""),
+)
-IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
+IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
+ "IMAGES_OPENAI_API_BASE_URL",
+ "image_generation.openai.api_base_url",
+ os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+IMAGES_OPENAI_API_KEY = PersistentConfig(
+ "IMAGES_OPENAI_API_KEY",
+ "image_generation.openai.api_key",
+ os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
+)
-IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
+IMAGE_SIZE = PersistentConfig(
+ "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+)
+
+IMAGE_STEPS = PersistentConfig(
+ "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
+)
+
+IMAGE_GENERATION_MODEL = PersistentConfig(
+ "IMAGE_GENERATION_MODEL",
+ "image_generation.model",
+ os.getenv("IMAGE_GENERATION_MODEL", ""),
+)
####################################
# Audio
####################################
-AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
-AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
-AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
-AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
+AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
+ "AUDIO_OPENAI_API_BASE_URL",
+ "audio.openai.api_base_url",
+ os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+AUDIO_OPENAI_API_KEY = PersistentConfig(
+ "AUDIO_OPENAI_API_KEY",
+ "audio.openai.api_key",
+ os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
+)
+AUDIO_OPENAI_API_MODEL = PersistentConfig(
+ "AUDIO_OPENAI_API_MODEL",
+ "audio.openai.api_model",
+ os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
+)
+AUDIO_OPENAI_API_VOICE = PersistentConfig(
+ "AUDIO_OPENAI_API_VOICE",
+ "audio.openai.api_voice",
+ os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
+)
####################################
# LiteLLM
diff --git a/backend/main.py b/backend/main.py
index 2d8d9ed68..8b7f9af69 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -59,6 +59,7 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
+ AppConfig,
)
from constants import ERROR_MESSAGES
@@ -107,10 +108,11 @@ app = FastAPI(
docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)
-app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+app.state.config = AppConfig()
+app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
+app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.WEBHOOK_URL = WEBHOOK_URL
+app.state.config.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"]
@@ -250,9 +252,9 @@ async def get_app_config():
"version": VERSION,
"auth": WEBUI_AUTH,
"default_locale": default_locale,
- "images": images_app.state.ENABLED,
- "default_models": webui_app.state.DEFAULT_MODELS,
- "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
+ "images": images_app.state.config.ENABLED,
+ "default_models": webui_app.state.config.DEFAULT_MODELS,
+ "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
}
@@ -261,8 +263,8 @@ async def get_app_config():
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
- "enabled": app.state.ENABLE_MODEL_FILTER,
- "models": app.state.MODEL_FILTER_LIST,
+ "enabled": app.state.config.ENABLE_MODEL_FILTER,
+ "models": app.state.config.MODEL_FILTER_LIST,
}
@@ -275,28 +277,28 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
- app.state.ENABLE_MODEL_FILTER = form_data.enabled
- app.state.MODEL_FILTER_LIST = form_data.models
+ app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
+ app.state.config.MODEL_FILTER_LIST, form_data.models
- ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
- ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+ ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+ ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
- openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
- openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+ openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+ openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
- litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
- litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+ litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+ litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return {
- "enabled": app.state.ENABLE_MODEL_FILTER,
- "models": app.state.MODEL_FILTER_LIST,
+ "enabled": app.state.config.ENABLE_MODEL_FILTER,
+ "models": app.state.config.MODEL_FILTER_LIST,
}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
- "url": app.state.WEBHOOK_URL,
+ "url": app.state.config.WEBHOOK_URL,
}
@@ -306,12 +308,12 @@ class UrlForm(BaseModel):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
- app.state.WEBHOOK_URL = form_data.url
+ app.state.config.WEBHOOK_URL = form_data.url
- webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
+ webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {
- "url": app.state.WEBHOOK_URL,
+ "url": app.state.config.WEBHOOK_URL,
}