feat: save UI config changes to config.json

This commit is contained in:
Jun Siang Cheah 2024-05-10 13:36:10 +08:00
parent 9a95767062
commit 058eb76568
11 changed files with 611 additions and 336 deletions

View File

@ -45,6 +45,8 @@ from config import (
AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE, AUDIO_OPENAI_API_VOICE,
config_get,
config_set,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -83,10 +85,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
} }
@ -97,17 +99,22 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
app.state.OPENAI_API_KEY = form_data.key config_set(app.state.OPENAI_API_KEY, form_data.key)
app.state.OPENAI_API_MODEL = form_data.model config_set(app.state.OPENAI_API_MODEL, form_data.model)
app.state.OPENAI_API_VOICE = form_data.speaker config_set(app.state.OPENAI_API_VOICE, form_data.speaker)
app.state.OPENAI_API_BASE_URL.save()
app.state.OPENAI_API_KEY.save()
app.state.OPENAI_API_MODEL.save()
app.state.OPENAI_API_VOICE.save()
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
} }

View File

@ -42,6 +42,8 @@ from config import (
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
config_get,
config_set,
) )
@ -79,7 +81,10 @@ app.state.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config") @app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
}
class ConfigUpdateForm(BaseModel): class ConfigUpdateForm(BaseModel):
@ -89,9 +94,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine config_set(app.state.ENGINE, form_data.engine)
app.state.ENABLED = form_data.enabled config_set(app.state.ENABLED, form_data.enabled)
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
}
class EngineUrlUpdateForm(BaseModel): class EngineUrlUpdateForm(BaseModel):
@ -102,8 +110,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url") @app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)): async def get_engine_url(user=Depends(get_admin_user)):
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
} }
@ -113,29 +121,29 @@ async def update_engine_url(
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url config_set(app.state.AUTOMATIC1111_BASE_URL, url)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.COMFYUI_BASE_URL = url config_set(app.state.COMFYUI_BASE_URL, url)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
"status": True, "status": True,
} }
@ -148,8 +156,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config") @app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
} }
@ -160,13 +168,13 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
app.state.OPENAI_API_KEY = form_data.key config_set(app.state.OPENAI_API_KEY, form_data.key)
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
} }
@ -176,7 +184,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size") @app.get("/size")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE} return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
@app.post("/size/update") @app.post("/size/update")
@ -185,9 +193,9 @@ async def update_image_size(
): ):
pattern = r"^\d+x\d+$" # Regular expression pattern pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size): if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size config_set(app.state.IMAGE_SIZE, form_data.size)
return { return {
"IMAGE_SIZE": app.state.IMAGE_SIZE, "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
"status": True, "status": True,
} }
else: else:
@ -203,7 +211,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps") @app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS} return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
@app.post("/steps/update") @app.post("/steps/update")
@ -211,9 +219,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.steps >= 0: if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps config_set(app.state.IMAGE_STEPS, form_data.steps)
return { return {
"IMAGE_STEPS": app.state.IMAGE_STEPS, "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
"status": True, "status": True,
} }
else: else:
@ -263,15 +271,25 @@ def get_models(user=Depends(get_current_user)):
async def get_default_model(user=Depends(get_admin_user)): async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} return {
"model": (
config_get(app.state.MODEL)
if config_get(app.state.MODEL)
else "dall-e-2"
)
}
elif app.state.ENGINE == "comfyui": elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""} return {
"model": (
config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
)
}
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json() options = r.json()
return {"model": options["sd_model_checkpoint"]} return {"model": options["sd_model_checkpoint"]}
except Exception as e: except Exception as e:
app.state.ENABLED = False config_set(app.state.ENABLED, False)
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -280,12 +298,9 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE == "openai": if app.state.ENGINE in ["openai", "comfyui"]:
app.state.MODEL = model config_set(app.state.MODEL, model)
return app.state.MODEL return config_get(app.state.MODEL)
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json() options = r.json()
@ -382,7 +397,7 @@ def generate_image(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
r = None r = None
try: try:
@ -396,7 +411,11 @@ def generate_image(
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"prompt": form_data.prompt, "prompt": form_data.prompt,
"n": form_data.n, "n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE, "size": (
form_data.size
if form_data.size
else config_get(app.state.IMAGE_SIZE)
),
"response_format": "b64_json", "response_format": "b64_json",
} }
@ -430,19 +449,19 @@ def generate_image(
"n": form_data.n, "n": form_data.n,
} }
if app.state.IMAGE_STEPS != None: if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = config_get(app.state.IMAGE_STEPS)
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = comfyui_generate_image(
app.state.MODEL, config_get(app.state.MODEL),
data, data,
user.id, user.id,
app.state.COMFYUI_BASE_URL, config_get(app.state.COMFYUI_BASE_URL),
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
@ -469,10 +488,10 @@ def generate_image(
"height": height, "height": height,
} }
if app.state.IMAGE_STEPS != None: if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = config_get(app.state.IMAGE_STEPS)
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
r = requests.post( r = requests.post(

View File

@ -46,6 +46,8 @@ from config import (
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
config_set,
config_get,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
@ -96,7 +98,7 @@ async def get_status():
@app.get("/urls") @app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)): async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update") @app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls config_set(app.state.OLLAMA_BASE_URLS, form_data.urls)
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
@ -153,7 +155,9 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] tasks = [
fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
models = { models = {
@ -179,14 +183,15 @@ async def get_ollama_tags(
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST, lambda model: model["name"]
in config_get(app.state.MODEL_FILTER_LIST),
models["models"], models["models"],
) )
) )
return models return models
return models return models
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
@ -216,7 +221,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx == None: if url_idx == None:
# returns lowest version # returns lowest version
tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] tasks = [
fetch_url(f"{url}/api/version")
for url in config_get(app.state.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
@ -235,7 +243,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
) )
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/version") r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status() r.raise_for_status()
@ -267,7 +275,7 @@ class ModelNameForm(BaseModel):
async def pull_model( async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -355,7 +363,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.debug(f"url: {url}") log.debug(f"url: {url}")
r = None r = None
@ -417,7 +425,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -490,7 +498,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -537,7 +545,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -577,7 +585,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
) )
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -634,7 +642,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -684,7 +692,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
@ -753,7 +761,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -856,7 +864,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -965,7 +973,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
@ -1064,7 +1072,7 @@ async def get_openai_models(
} }
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
@ -1198,7 +1206,7 @@ async def download_model(
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx] url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
file_name = parse_huggingface_url(form_data.url) file_name = parse_huggingface_url(form_data.url)
@ -1217,7 +1225,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}" file_path = f"{UPLOAD_DIR}/{file.filename}"
@ -1282,7 +1290,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None: # if url_idx == None:
# url_idx = 0 # url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx] # url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename) # file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size # total_size = file.size
@ -1319,7 +1327,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy( async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user) path: str, request: Request, user=Depends(get_verified_user)
): ):
url = app.state.OLLAMA_BASE_URLS[0] url = config_get(app.state.OLLAMA_BASE_URLS)[0]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
body = await request.body() body = await request.body()

View File

@ -26,6 +26,8 @@ from config import (
CACHE_DIR, CACHE_DIR,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
config_set,
config_get,
) )
from typing import List, Optional from typing import List, Optional
@ -75,32 +77,34 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls") @app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)): async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
@app.post("/urls/update") @app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models() await get_all_models()
app.state.OPENAI_API_BASE_URLS = form_data.urls config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls)
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
@app.get("/keys") @app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)): async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
@app.post("/keys/update") @app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys config_set(app.state.OPENAI_API_KEYS, form_data.keys)
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None
try: try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") idx = config_get(app.state.OPENAI_API_BASE_URLS).index(
"https://api.openai.com/v1"
)
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
@ -114,13 +118,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" headers["Authorization"] = (
f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}"
)
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
@ -180,7 +186,8 @@ def merge_models_lists(model_lists):
[ [
{**model, "urlIdx": idx} {**model, "urlIdx": idx}
for model in models for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] if "api.openai.com"
not in config_get(app.state.OPENAI_API_BASE_URLS)[idx]
or "gpt" in model["id"] or "gpt" in model["id"]
] ]
) )
@ -191,12 +198,15 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": if (
len(config_get(app.state.OPENAI_API_KEYS)) == 1
and config_get(app.state.OPENAI_API_KEYS)[0] == ""
):
models = {"data": []} models = {"data": []}
else: else:
tasks = [ tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS))
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
@ -228,18 +238,19 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if config_get(app.state.ENABLE_MODEL_FILTER):
if user.role == "user": if user.role == "user":
models["data"] = list( models["data"] = list(
filter( filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST, lambda model: model["id"]
in config_get(app.state.MODEL_FILTER_LIST),
models["data"], models["data"],
) )
) )
return models return models
return models return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx]
r = None r = None
@ -303,8 +314,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e) log.error("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx] url = config_get(app.state.OPENAI_API_BASE_URLS)[idx]
key = app.state.OPENAI_API_KEYS[idx] key = config_get(app.state.OPENAI_API_KEYS)[idx]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"

View File

@ -93,6 +93,8 @@ from config import (
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
config_set,
config_get,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -133,7 +135,7 @@ def update_embedding_model(
embedding_model: str, embedding_model: str,
update_model: bool = False, update_model: bool = False,
): ):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model), get_model_path(embedding_model, update_model),
device=DEVICE_TYPE, device=DEVICE_TYPE,
@ -158,22 +160,22 @@ def update_reranking_model(
update_embedding_model( update_embedding_model(
app.state.RAG_EMBEDDING_MODEL, config_get(app.state.RAG_EMBEDDING_MODEL),
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
) )
update_reranking_model( update_reranking_model(
app.state.RAG_RERANKING_MODEL, config_get(app.state.RAG_RERANKING_MODEL),
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, config_get(app.state.RAG_EMBEDDING_ENGINE),
app.state.RAG_EMBEDDING_MODEL, config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, config_get(app.state.OPENAI_API_KEY),
app.state.OPENAI_API_BASE_URL, config_get(app.state.OPENAI_API_BASE_URL),
) )
origins = ["*"] origins = ["*"]
@ -200,12 +202,12 @@ class UrlForm(CollectionNameForm):
async def get_status(): async def get_status():
return { return {
"status": True, "status": True,
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
"template": app.state.RAG_TEMPLATE, "template": config_get(app.state.RAG_TEMPLATE),
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
} }
@ -213,18 +215,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)): async def get_embedding_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": config_get(app.state.OPENAI_API_BASE_URL),
"key": app.state.OPENAI_API_KEY, "key": config_get(app.state.OPENAI_API_KEY),
}, },
} }
@app.get("/reranking") @app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)): async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} return {
"status": True,
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
}
class OpenAIConfigForm(BaseModel): class OpenAIConfigForm(BaseModel):
@ -246,31 +251,31 @@ async def update_embedding_config(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model)
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]:
if form_data.openai_config != None: if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url)
app.state.OPENAI_API_KEY = form_data.openai_config.key config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key)
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True)
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, config_get(app.state.RAG_EMBEDDING_ENGINE),
app.state.RAG_EMBEDDING_MODEL, config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, config_get(app.state.OPENAI_API_KEY),
app.state.OPENAI_API_BASE_URL, config_get(app.state.OPENAI_API_BASE_URL),
) )
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": config_get(app.state.OPENAI_API_BASE_URL),
"key": app.state.OPENAI_API_KEY, "key": config_get(app.state.OPENAI_API_KEY),
}, },
} }
except Exception as e: except Exception as e:
@ -293,13 +298,13 @@ async def update_reranking_config(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
) )
try: try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model)
update_reranking_model(app.state.RAG_RERANKING_MODEL, True) update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True)
return { return {
"status": True, "status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
} }
except Exception as e: except Exception as e:
log.exception(f"Problem updating reranking model: {e}") log.exception(f"Problem updating reranking model: {e}")
@ -313,14 +318,16 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)): async def get_rag_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
@ -345,50 +352,69 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = ( config_set(
form_data.pdf_extract_images app.state.PDF_EXTRACT_IMAGES,
if form_data.pdf_extract_images != None (
else app.state.PDF_EXTRACT_IMAGES form_data.pdf_extract_images
if form_data.pdf_extract_images is not None
else config_get(app.state.PDF_EXTRACT_IMAGES)
),
) )
app.state.CHUNK_SIZE = ( config_set(
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE app.state.CHUNK_SIZE,
(
form_data.chunk.chunk_size
if form_data.chunk is not None
else config_get(app.state.CHUNK_SIZE)
),
) )
app.state.CHUNK_OVERLAP = ( config_set(
form_data.chunk.chunk_overlap app.state.CHUNK_OVERLAP,
if form_data.chunk != None (
else app.state.CHUNK_OVERLAP form_data.chunk.chunk_overlap
if form_data.chunk is not None
else config_get(app.state.CHUNK_OVERLAP)
),
) )
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( config_set(
form_data.web_loader_ssl_verification app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
if form_data.web_loader_ssl_verification != None (
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION)
),
) )
app.state.YOUTUBE_LOADER_LANGUAGE = ( config_set(
form_data.youtube.language app.state.YOUTUBE_LOADER_LANGUAGE,
if form_data.youtube != None (
else app.state.YOUTUBE_LOADER_LANGUAGE form_data.youtube.language
if form_data.youtube is not None
else config_get(app.state.YOUTUBE_LOADER_LANGUAGE)
),
) )
app.state.YOUTUBE_LOADER_TRANSLATION = ( app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation form_data.youtube.translation
if form_data.youtube != None if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION else app.state.YOUTUBE_LOADER_TRANSLATION
) )
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
@ -398,7 +424,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)): async def get_rag_template(user=Depends(get_current_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": config_get(app.state.RAG_TEMPLATE),
} }
@ -406,10 +432,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)): async def get_query_settings(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": config_get(app.state.RAG_TEMPLATE),
"k": app.state.TOP_K, "k": config_get(app.state.TOP_K),
"r": app.state.RELEVANCE_THRESHOLD, "r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
} }
@ -424,16 +450,22 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings( async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user) form_data: QuerySettingsForm, user=Depends(get_admin_user)
): ):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE config_set(
app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RAG_TEMPLATE,
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 form_data.template if form_data.template else RAG_TEMPLATE,
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False )
config_set(app.state.TOP_K, form_data.k if form_data.k else 4)
config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0)
config_set(
app.state.ENABLE_RAG_HYBRID_SEARCH,
form_data.hybrid if form_data.hybrid else False,
)
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": config_get(app.state.RAG_TEMPLATE),
"k": app.state.TOP_K, "k": config_get(app.state.TOP_K),
"r": app.state.RELEVANCE_THRESHOLD, "r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
} }
@ -451,21 +483,25 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
return query_doc_with_hybrid_search( return query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else config_get(app.state.TOP_K),
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
),
) )
else: else:
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else config_get(app.state.TOP_K),
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -489,21 +525,25 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else config_get(app.state.TOP_K),
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
),
) )
else: else:
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else config_get(app.state.TOP_K),
) )
except Exception as e: except Exception as e:
@ -520,8 +560,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url( loader = YoutubeLoader.from_youtube_url(
form_data.url, form_data.url,
add_video_info=True, add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE, language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
translation=app.state.YOUTUBE_LOADER_TRANSLATION, translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION),
) )
data = loader.load() data = loader.load()
@ -548,7 +588,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader( loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION form_data.url,
verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION),
) )
data = loader.load() data = loader.load()
@ -604,8 +645,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
add_start_index=True, add_start_index=True,
) )
@ -622,8 +663,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False text, metadata, collection_name, overwrite: bool = False
) -> bool: ) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.create_documents([text], metadatas=[metadata]) docs = text_splitter.create_documents([text], metadatas=[metadata])
@ -646,11 +687,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function( embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, config_get(app.state.RAG_EMBEDDING_ENGINE),
app.state.RAG_EMBEDDING_MODEL, config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, config_get(app.state.OPENAI_API_KEY),
app.state.OPENAI_API_BASE_URL, config_get(app.state.OPENAI_API_BASE_URL),
) )
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@ -724,7 +765,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
] ]
if file_ext == "pdf": if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) loader = PyPDFLoader(
file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES)
)
elif file_ext == "csv": elif file_ext == "csv":
loader = CSVLoader(file_path) loader = CSVLoader(file_path)
elif file_ext == "rst": elif file_ext == "rst":

View File

@ -21,6 +21,8 @@ from config import (
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
config_get,
) )
app = FastAPI() app = FastAPI()
@ -28,7 +30,7 @@ app = FastAPI()
origins = ["*"] origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.JWT_EXPIRES_IN = "-1" app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
@ -61,6 +63,6 @@ async def get_status():
return { return {
"status": True, "status": True,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS, "default_models": config_get(app.state.DEFAULT_MODELS),
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS),
} }

View File

@ -33,7 +33,7 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set
router = APIRouter() router = APIRouter()
@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)),
) )
return { return {
@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
else request.app.state.DEFAULT_USER_ROLE else config_get(request.app.state.DEFAULT_USER_ROLE)
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
@ -194,13 +194,15 @@ async def signup(request: Request, form_data: SignupForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(
config_get(request.app.state.JWT_EXPIRES_IN)
),
) )
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL: if config_get(request.app.state.WEBHOOK_URL):
post_webhook( post_webhook(
request.app.state.WEBHOOK_URL, config_get(request.app.state.WEBHOOK_URL),
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
@ -276,13 +278,15 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP return config_get(request.app.state.ENABLE_SIGNUP)
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP config_set(
return request.app.state.ENABLE_SIGNUP request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP)
)
return config_get(request.app.state.ENABLE_SIGNUP)
############################ ############################
@ -292,7 +296,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role") @router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)): async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.DEFAULT_USER_ROLE return config_get(request.app.state.DEFAULT_USER_ROLE)
class UpdateRoleForm(BaseModel): class UpdateRoleForm(BaseModel):
@ -304,8 +308,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
): ):
if form_data.role in ["pending", "user", "admin"]: if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role)
return request.app.state.DEFAULT_USER_ROLE return config_get(request.app.state.DEFAULT_USER_ROLE)
############################ ############################
@ -315,7 +319,7 @@ async def update_default_user_role(
@router.get("/token/expires") @router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN return config_get(request.app.state.JWT_EXPIRES_IN)
class UpdateJWTExpiresDurationForm(BaseModel): class UpdateJWTExpiresDurationForm(BaseModel):
@ -332,10 +336,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern # Check if the input string matches the pattern
if re.match(pattern, form_data.duration): if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration)
return request.app.state.JWT_EXPIRES_IN return config_get(request.app.state.JWT_EXPIRES_IN)
else: else:
return request.app.state.JWT_EXPIRES_IN return config_get(request.app.state.JWT_EXPIRES_IN)
############################ ############################

View File

@ -9,6 +9,7 @@ import time
import uuid import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from config import config_set, config_get
from utils.utils import ( from utils.utils import (
get_password_hash, get_password_hash,
@ -44,8 +45,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async def set_global_default_models( async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
): ):
request.app.state.DEFAULT_MODELS = form_data.models config_set(request.app.state.DEFAULT_MODELS, form_data.models)
return request.app.state.DEFAULT_MODELS return config_get(request.app.state.DEFAULT_MODELS)
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
@ -55,5 +56,5 @@ async def set_global_default_suggestions(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
data = form_data.model_dump() data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"])
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS)

View File

@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS, config_set, config_get
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@router.get("/permissions/user") @router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)): async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.USER_PERMISSIONS return config_get(request.app.state.USER_PERMISSIONS)
@router.post("/permissions/user") @router.post("/permissions/user")
async def update_user_permissions( async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user) request: Request, form_data: dict, user=Depends(get_admin_user)
): ):
request.app.state.USER_PERMISSIONS = form_data config_set(request.app.state.USER_PERMISSIONS, form_data)
return request.app.state.USER_PERMISSIONS return config_get(request.app.state.USER_PERMISSIONS)
############################ ############################

View File

@ -5,6 +5,7 @@ import chromadb
from chromadb import Settings from chromadb import Settings
from base64 import b64encode from base64 import b64encode
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from pathlib import Path from pathlib import Path
import json import json
@ -17,7 +18,6 @@ import shutil
from secrets import token_bytes from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### ####################################
# Load .env file # Load .env file
#################################### ####################################
@ -29,7 +29,6 @@ try:
except ImportError: except ImportError:
print("dotenv not installed, skipping...") print("dotenv not installed, skipping...")
#################################### ####################################
# LOGGING # LOGGING
#################################### ####################################
@ -71,7 +70,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"]) log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI": if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)" WEBUI_NAME += " (Open WebUI)"
@ -80,7 +78,6 @@ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
#################################### ####################################
# ENV (dev,test,prod) # ENV (dev,test,prod)
#################################### ####################################
@ -151,26 +148,14 @@ for version in soup.find_all("h2"):
changelog_json[version_number] = version_data changelog_json[version_number] = version_data
CHANGELOG = changelog_json CHANGELOG = changelog_json
#################################### ####################################
# WEBUI_VERSION # WEBUI_VERSION
#################################### ####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
#################################### ####################################
# DATA/FRONTEND BUILD DIR # DATA/FRONTEND BUILD DIR
#################################### ####################################
@ -184,6 +169,93 @@ try:
except: except:
CONFIG_DATA = {} CONFIG_DATA = {}
####################################
# Config helpers
####################################
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
json.dump(CONFIG_DATA, f, indent="\t")
except Exception as e:
log.exception(e)
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
T = TypeVar("T")
class WrappedConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
self.value = self.config_value
else:
self.value = env_value
def __str__(self):
return str(self.value)
def save(self):
# Don't save if the value is the same as the env value and the config value
if self.env_value == self.value:
if self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to config.json")
path_parts = self.config_path.split(".")
config = CONFIG_DATA
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
save_config()
self.config_value = self.value
def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True):
if isinstance(config, WrappedConfig):
config.value = value
if save_config:
config.save()
else:
config = value
def config_get(config: Union[WrappedConfig[T], T]) -> T:
if isinstance(config, WrappedConfig):
return config.value
return config
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
JWT_EXPIRES_IN = WrappedConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
#################################### ####################################
# Static DIR # Static DIR
#################################### ####################################
@ -225,7 +297,6 @@ if CUSTOM_NAME:
log.exception(e) log.exception(e)
pass pass
#################################### ####################################
# File Upload DIR # File Upload DIR
#################################### ####################################
@ -233,7 +304,6 @@ if CUSTOM_NAME:
UPLOAD_DIR = f"{DATA_DIR}/uploads" UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# Cache DIR # Cache DIR
#################################### ####################################
@ -241,7 +311,6 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache" CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# Docs DIR # Docs DIR
#################################### ####################################
@ -282,7 +351,6 @@ if not os.path.exists(LITELLM_CONFIG_PATH):
create_config_file(LITELLM_CONFIG_PATH) create_config_file(LITELLM_CONFIG_PATH)
log.info("Config file created successfully.") log.info("Config file created successfully.")
#################################### ####################################
# OLLAMA_BASE_URL # OLLAMA_BASE_URL
#################################### ####################################
@ -313,12 +381,13 @@ if ENV == "prod":
elif K8S_FLAG: elif K8S_FLAG:
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = WrappedConfig(
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
#################################### ####################################
# OPENAI_API # OPENAI_API
@ -327,7 +396,6 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "": if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@ -335,7 +403,7 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = WrappedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS)
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = ( OPENAI_API_BASE_URLS = (
@ -346,37 +414,42 @@ OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1" url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";") for url in OPENAI_API_BASE_URLS.split(";")
] ]
OPENAI_API_BASE_URLS = WrappedConfig(
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)
OPENAI_API_KEY = "" OPENAI_API_KEY = ""
try: try:
OPENAI_API_KEY = OPENAI_API_KEYS[ OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
] ]
except: except:
pass pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# WEBUI # WEBUI
#################################### ####################################
ENABLE_SIGNUP = ( ENABLE_SIGNUP = WrappedConfig(
False "ENABLE_SIGNUP",
if WEBUI_AUTH == False "ui.enable_signup",
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" (
False
if not WEBUI_AUTH
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
),
)
DEFAULT_MODELS = WrappedConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
) )
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig(
DEFAULT_PROMPT_SUGGESTIONS = ( "DEFAULT_PROMPT_SUGGESTIONS",
CONFIG_DATA["ui"]["prompt_suggestions"] "ui.prompt_suggestions",
if "ui" in CONFIG_DATA [
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
{ {
"title": ["Help me study", "vocabulary for a college entrance exam"], "title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@ -404,23 +477,42 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Overcome procrastination", "give me tips"], "title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
}, },
] ],
) )
DEFAULT_USER_ROLE = WrappedConfig(
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") "DEFAULT_USER_ROLE",
"ui.default_user_role",
USER_PERMISSIONS_CHAT_DELETION = ( os.getenv("DEFAULT_USER_ROLE", "pending"),
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
) )
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} USER_PERMISSIONS_CHAT_DELETION = WrappedConfig(
"USER_PERMISSIONS_CHAT_DELETION",
"ui.user_permissions.chat.deletion",
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true",
)
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true" USER_PERMISSIONS = WrappedConfig(
"USER_PERMISSIONS",
"ui.user_permissions",
{"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)
ENABLE_MODEL_FILTER = WrappedConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] MODEL_FILTER_LIST = WrappedConfig(
"MODEL_FILTER_LIST",
"model_filter.list",
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
)
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") WEBHOOK_URL = WrappedConfig(
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
@ -458,26 +550,45 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_TOP_K = WrappedConfig(
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)
ENABLE_RAG_HYBRID_SEARCH = ( RAG_RELEVANCE_THRESHOLD = WrappedConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" "RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold",
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
) )
ENABLE_RAG_HYBRID_SEARCH = WrappedConfig(
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( "ENABLE_RAG_HYBRID_SEARCH",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true" "rag.enable_hybrid_search",
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
) )
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = WrappedConfig(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" "rag.enable_web_loader_ssl_verification",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
) )
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_ENGINE = WrappedConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)
PDF_EXTRACT_IMAGES = WrappedConfig(
"PDF_EXTRACT_IMAGES",
"rag.pdf_extract_images",
os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)
RAG_EMBEDDING_MODEL = WrappedConfig(
"RAG_EMBEDDING_MODEL",
"rag.embedding_model",
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
@ -487,9 +598,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
) )
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") RAG_RERANKING_MODEL = WrappedConfig(
if not RAG_RERANKING_MODEL == "": "RAG_RERANKING_MODEL",
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"), "rag.reranking_model",
os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = ( RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
@ -499,7 +614,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
) )
if CHROMA_HTTP_HOST != "": if CHROMA_HTTP_HOST != "":
CHROMA_CLIENT = chromadb.HttpClient( CHROMA_CLIENT = chromadb.HttpClient(
host=CHROMA_HTTP_HOST, host=CHROMA_HTTP_HOST,
@ -518,7 +632,6 @@ else:
database=CHROMA_DATABASE, database=CHROMA_DATABASE,
) )
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
@ -527,9 +640,14 @@ if USE_CUDA.lower() == "true":
else: else:
DEVICE_TYPE = "cpu" DEVICE_TYPE = "cpu"
CHUNK_SIZE = WrappedConfig(
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) )
CHUNK_OVERLAP = WrappedConfig(
"CHUNK_OVERLAP",
"rag.chunk_overlap",
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context> <context>
@ -545,16 +663,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query. Given the context information, answer the query.
Query: [query]""" Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) RAG_TEMPLATE = WrappedConfig(
"RAG_TEMPLATE",
"rag.template",
os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_BASE_URL = WrappedConfig(
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) "RAG_OPENAI_API_BASE_URL",
"rag.openai_api_base_url",
os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = WrappedConfig(
"RAG_OPENAI_API_KEY",
"rag.openai_api_key",
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
ENABLE_RAG_LOCAL_WEB_FETCH = ( ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true" os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
) )
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",") YOUTUBE_LOADER_LANGUAGE = WrappedConfig(
"YOUTUBE_LOADER_LANGUAGE",
"rag.youtube_loader_language",
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
#################################### ####################################
# Transcribe # Transcribe
@ -566,39 +700,82 @@ WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
) )
#################################### ####################################
# Images # Images
#################################### ####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "") IMAGE_GENERATION_ENGINE = WrappedConfig(
"IMAGE_GENERATION_ENGINE",
ENABLE_IMAGE_GENERATION = ( "image_generation.engine",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" os.getenv("IMAGE_GENERATION_ENGINE", ""),
) )
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") ENABLE_IMAGE_GENERATION = WrappedConfig(
"ENABLE_IMAGE_GENERATION",
IMAGES_OPENAI_API_BASE_URL = os.getenv( "image_generation.enable",
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
AUTOMATIC1111_BASE_URL = WrappedConfig(
"AUTOMATIC1111_BASE_URL",
"image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""),
) )
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512") COMFYUI_BASE_URL = WrappedConfig(
"COMFYUI_BASE_URL",
"image_generation.comfyui.base_url",
os.getenv("COMFYUI_BASE_URL", ""),
)
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50)) IMAGES_OPENAI_API_BASE_URL = WrappedConfig(
"IMAGES_OPENAI_API_BASE_URL",
"image_generation.openai.api_base_url",
os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = WrappedConfig(
"IMAGES_OPENAI_API_KEY",
"image_generation.openai.api_key",
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "") IMAGE_SIZE = WrappedConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
IMAGE_STEPS = WrappedConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)
IMAGE_GENERATION_MODEL = WrappedConfig(
"IMAGE_GENERATION_MODEL",
"image_generation.model",
os.getenv("IMAGE_GENERATION_MODEL", ""),
)
#################################### ####################################
# Audio # Audio
#################################### ####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) AUDIO_OPENAI_API_BASE_URL = WrappedConfig(
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) "AUDIO_OPENAI_API_BASE_URL",
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1") "audio.openai.api_base_url",
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy") os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = WrappedConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = WrappedConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
)
AUDIO_OPENAI_API_VOICE = WrappedConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
#################################### ####################################
# LiteLLM # LiteLLM
@ -612,7 +789,6 @@ if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT") raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
#################################### ####################################
# Database # Database
#################################### ####################################

View File

@ -58,6 +58,8 @@ from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
WEBHOOK_URL, WEBHOOK_URL,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
config_get,
config_set,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -243,9 +245,11 @@ async def get_app_config():
"version": VERSION, "version": VERSION,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_locale": default_locale, "default_locale": default_locale,
"images": images_app.state.ENABLED, "images": config_get(images_app.state.ENABLED),
"default_models": webui_app.state.DEFAULT_MODELS, "default_models": config_get(webui_app.state.DEFAULT_MODELS),
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": config_get(
webui_app.state.DEFAULT_PROMPT_SUGGESTIONS
),
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT, "admin_export_enabled": ENABLE_ADMIN_EXPORT,
} }
@ -254,8 +258,8 @@ async def get_app_config():
@app.get("/api/config/model/filter") @app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)): async def get_model_filter_config(user=Depends(get_admin_user)):
return { return {
"enabled": app.state.ENABLE_MODEL_FILTER, "enabled": config_get(app.state.ENABLE_MODEL_FILTER),
"models": app.state.MODEL_FILTER_LIST, "models": config_get(app.state.MODEL_FILTER_LIST),
} }
@ -268,28 +272,28 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
app.state.ENABLE_MODEL_FILTER = form_data.enabled config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled)
app.state.MODEL_FILTER_LIST = form_data.models config_set(app.state.MODEL_FILTER_LIST, form_data.models)
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
return { return {
"enabled": app.state.ENABLE_MODEL_FILTER, "enabled": config_get(app.state.ENABLE_MODEL_FILTER),
"models": app.state.MODEL_FILTER_LIST, "models": config_get(app.state.MODEL_FILTER_LIST),
} }
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
"url": app.state.WEBHOOK_URL, "url": config_get(app.state.WEBHOOK_URL),
} }
@ -299,12 +303,12 @@ class UrlForm(BaseModel):
@app.post("/api/webhook") @app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url config_set(app.state.WEBHOOK_URL, form_data.url)
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL)
return { return {
"url": app.state.WEBHOOK_URL, "url": config_get(app.state.WEBHOOK_URL),
} }