mirror of
https://github.com/open-webui/open-webui
synced 2025-04-16 21:42:50 +00:00
feat: save UI config changes to config.json
This commit is contained in:
parent
9a95767062
commit
058eb76568
@ -45,6 +45,8 @@ from config import (
|
|||||||
AUDIO_OPENAI_API_KEY,
|
AUDIO_OPENAI_API_KEY,
|
||||||
AUDIO_OPENAI_API_MODEL,
|
AUDIO_OPENAI_API_MODEL,
|
||||||
AUDIO_OPENAI_API_VOICE,
|
AUDIO_OPENAI_API_VOICE,
|
||||||
|
config_get,
|
||||||
|
config_set,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -83,10 +85,10 @@ class OpenAIConfigUpdateForm(BaseModel):
|
|||||||
@app.get("/config")
|
@app.get("/config")
|
||||||
async def get_openai_config(user=Depends(get_admin_user)):
|
async def get_openai_config(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||||
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
|
"OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
|
||||||
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
|
"OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -97,17 +99,22 @@ async def update_openai_config(
|
|||||||
if form_data.key == "":
|
if form_data.key == "":
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
app.state.OPENAI_API_BASE_URL = form_data.url
|
config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
|
||||||
app.state.OPENAI_API_KEY = form_data.key
|
config_set(app.state.OPENAI_API_KEY, form_data.key)
|
||||||
app.state.OPENAI_API_MODEL = form_data.model
|
config_set(app.state.OPENAI_API_MODEL, form_data.model)
|
||||||
app.state.OPENAI_API_VOICE = form_data.speaker
|
config_set(app.state.OPENAI_API_VOICE, form_data.speaker)
|
||||||
|
|
||||||
|
app.state.OPENAI_API_BASE_URL.save()
|
||||||
|
app.state.OPENAI_API_KEY.save()
|
||||||
|
app.state.OPENAI_API_MODEL.save()
|
||||||
|
app.state.OPENAI_API_VOICE.save()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||||
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
|
"OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
|
||||||
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
|
"OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +42,8 @@ from config import (
|
|||||||
IMAGE_GENERATION_MODEL,
|
IMAGE_GENERATION_MODEL,
|
||||||
IMAGE_SIZE,
|
IMAGE_SIZE,
|
||||||
IMAGE_STEPS,
|
IMAGE_STEPS,
|
||||||
|
config_get,
|
||||||
|
config_set,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +81,10 @@ app.state.IMAGE_STEPS = IMAGE_STEPS
|
|||||||
|
|
||||||
@app.get("/config")
|
@app.get("/config")
|
||||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||||
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
|
return {
|
||||||
|
"engine": config_get(app.state.ENGINE),
|
||||||
|
"enabled": config_get(app.state.ENABLED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdateForm(BaseModel):
|
class ConfigUpdateForm(BaseModel):
|
||||||
@ -89,9 +94,12 @@ class ConfigUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.post("/config/update")
|
@app.post("/config/update")
|
||||||
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
||||||
app.state.ENGINE = form_data.engine
|
config_set(app.state.ENGINE, form_data.engine)
|
||||||
app.state.ENABLED = form_data.enabled
|
config_set(app.state.ENABLED, form_data.enabled)
|
||||||
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
|
return {
|
||||||
|
"engine": config_get(app.state.ENGINE),
|
||||||
|
"enabled": config_get(app.state.ENABLED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class EngineUrlUpdateForm(BaseModel):
|
class EngineUrlUpdateForm(BaseModel):
|
||||||
@ -102,8 +110,8 @@ class EngineUrlUpdateForm(BaseModel):
|
|||||||
@app.get("/url")
|
@app.get("/url")
|
||||||
async def get_engine_url(user=Depends(get_admin_user)):
|
async def get_engine_url(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
|
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
|
||||||
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
|
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -113,29 +121,29 @@ async def update_engine_url(
|
|||||||
):
|
):
|
||||||
|
|
||||||
if form_data.AUTOMATIC1111_BASE_URL == None:
|
if form_data.AUTOMATIC1111_BASE_URL == None:
|
||||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
|
||||||
else:
|
else:
|
||||||
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
|
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
|
||||||
try:
|
try:
|
||||||
r = requests.head(url)
|
r = requests.head(url)
|
||||||
app.state.AUTOMATIC1111_BASE_URL = url
|
config_set(app.state.AUTOMATIC1111_BASE_URL, url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
if form_data.COMFYUI_BASE_URL == None:
|
if form_data.COMFYUI_BASE_URL == None:
|
||||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
|
||||||
else:
|
else:
|
||||||
url = form_data.COMFYUI_BASE_URL.strip("/")
|
url = form_data.COMFYUI_BASE_URL.strip("/")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = requests.head(url)
|
r = requests.head(url)
|
||||||
app.state.COMFYUI_BASE_URL = url
|
config_set(app.state.COMFYUI_BASE_URL, url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
|
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
|
||||||
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
|
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
|
||||||
"status": True,
|
"status": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,8 +156,8 @@ class OpenAIConfigUpdateForm(BaseModel):
|
|||||||
@app.get("/openai/config")
|
@app.get("/openai/config")
|
||||||
async def get_openai_config(user=Depends(get_admin_user)):
|
async def get_openai_config(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -160,13 +168,13 @@ async def update_openai_config(
|
|||||||
if form_data.key == "":
|
if form_data.key == "":
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
app.state.OPENAI_API_BASE_URL = form_data.url
|
config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
|
||||||
app.state.OPENAI_API_KEY = form_data.key
|
config_set(app.state.OPENAI_API_KEY, form_data.key)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -176,7 +184,7 @@ class ImageSizeUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.get("/size")
|
@app.get("/size")
|
||||||
async def get_image_size(user=Depends(get_admin_user)):
|
async def get_image_size(user=Depends(get_admin_user)):
|
||||||
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
|
return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/size/update")
|
@app.post("/size/update")
|
||||||
@ -185,9 +193,9 @@ async def update_image_size(
|
|||||||
):
|
):
|
||||||
pattern = r"^\d+x\d+$" # Regular expression pattern
|
pattern = r"^\d+x\d+$" # Regular expression pattern
|
||||||
if re.match(pattern, form_data.size):
|
if re.match(pattern, form_data.size):
|
||||||
app.state.IMAGE_SIZE = form_data.size
|
config_set(app.state.IMAGE_SIZE, form_data.size)
|
||||||
return {
|
return {
|
||||||
"IMAGE_SIZE": app.state.IMAGE_SIZE,
|
"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
|
||||||
"status": True,
|
"status": True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@ -203,7 +211,7 @@ class ImageStepsUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.get("/steps")
|
@app.get("/steps")
|
||||||
async def get_image_size(user=Depends(get_admin_user)):
|
async def get_image_size(user=Depends(get_admin_user)):
|
||||||
return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
|
return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/steps/update")
|
@app.post("/steps/update")
|
||||||
@ -211,9 +219,9 @@ async def update_image_size(
|
|||||||
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
|
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if form_data.steps >= 0:
|
if form_data.steps >= 0:
|
||||||
app.state.IMAGE_STEPS = form_data.steps
|
config_set(app.state.IMAGE_STEPS, form_data.steps)
|
||||||
return {
|
return {
|
||||||
"IMAGE_STEPS": app.state.IMAGE_STEPS,
|
"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
|
||||||
"status": True,
|
"status": True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@ -263,15 +271,25 @@ def get_models(user=Depends(get_current_user)):
|
|||||||
async def get_default_model(user=Depends(get_admin_user)):
|
async def get_default_model(user=Depends(get_admin_user)):
|
||||||
try:
|
try:
|
||||||
if app.state.ENGINE == "openai":
|
if app.state.ENGINE == "openai":
|
||||||
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
|
return {
|
||||||
|
"model": (
|
||||||
|
config_get(app.state.MODEL)
|
||||||
|
if config_get(app.state.MODEL)
|
||||||
|
else "dall-e-2"
|
||||||
|
)
|
||||||
|
}
|
||||||
elif app.state.ENGINE == "comfyui":
|
elif app.state.ENGINE == "comfyui":
|
||||||
return {"model": app.state.MODEL if app.state.MODEL else ""}
|
return {
|
||||||
|
"model": (
|
||||||
|
config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
|
||||||
|
)
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||||
options = r.json()
|
options = r.json()
|
||||||
return {"model": options["sd_model_checkpoint"]}
|
return {"model": options["sd_model_checkpoint"]}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
app.state.ENABLED = False
|
config_set(app.state.ENABLED, False)
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
|
|
||||||
@ -280,12 +298,9 @@ class UpdateModelForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def set_model_handler(model: str):
|
def set_model_handler(model: str):
|
||||||
if app.state.ENGINE == "openai":
|
if app.state.ENGINE in ["openai", "comfyui"]:
|
||||||
app.state.MODEL = model
|
config_set(app.state.MODEL, model)
|
||||||
return app.state.MODEL
|
return config_get(app.state.MODEL)
|
||||||
if app.state.ENGINE == "comfyui":
|
|
||||||
app.state.MODEL = model
|
|
||||||
return app.state.MODEL
|
|
||||||
else:
|
else:
|
||||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||||
options = r.json()
|
options = r.json()
|
||||||
@ -382,7 +397,7 @@ def generate_image(
|
|||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
|
width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
@ -396,7 +411,11 @@ def generate_image(
|
|||||||
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
|
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
|
||||||
"prompt": form_data.prompt,
|
"prompt": form_data.prompt,
|
||||||
"n": form_data.n,
|
"n": form_data.n,
|
||||||
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
|
"size": (
|
||||||
|
form_data.size
|
||||||
|
if form_data.size
|
||||||
|
else config_get(app.state.IMAGE_SIZE)
|
||||||
|
),
|
||||||
"response_format": "b64_json",
|
"response_format": "b64_json",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,19 +449,19 @@ def generate_image(
|
|||||||
"n": form_data.n,
|
"n": form_data.n,
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.state.IMAGE_STEPS != None:
|
if config_get(app.state.IMAGE_STEPS) is not None:
|
||||||
data["steps"] = app.state.IMAGE_STEPS
|
data["steps"] = config_get(app.state.IMAGE_STEPS)
|
||||||
|
|
||||||
if form_data.negative_prompt != None:
|
if form_data.negative_prompt is not None:
|
||||||
data["negative_prompt"] = form_data.negative_prompt
|
data["negative_prompt"] = form_data.negative_prompt
|
||||||
|
|
||||||
data = ImageGenerationPayload(**data)
|
data = ImageGenerationPayload(**data)
|
||||||
|
|
||||||
res = comfyui_generate_image(
|
res = comfyui_generate_image(
|
||||||
app.state.MODEL,
|
config_get(app.state.MODEL),
|
||||||
data,
|
data,
|
||||||
user.id,
|
user.id,
|
||||||
app.state.COMFYUI_BASE_URL,
|
config_get(app.state.COMFYUI_BASE_URL),
|
||||||
)
|
)
|
||||||
log.debug(f"res: {res}")
|
log.debug(f"res: {res}")
|
||||||
|
|
||||||
@ -469,10 +488,10 @@ def generate_image(
|
|||||||
"height": height,
|
"height": height,
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.state.IMAGE_STEPS != None:
|
if config_get(app.state.IMAGE_STEPS) is not None:
|
||||||
data["steps"] = app.state.IMAGE_STEPS
|
data["steps"] = config_get(app.state.IMAGE_STEPS)
|
||||||
|
|
||||||
if form_data.negative_prompt != None:
|
if form_data.negative_prompt is not None:
|
||||||
data["negative_prompt"] = form_data.negative_prompt
|
data["negative_prompt"] = form_data.negative_prompt
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
|
@ -46,6 +46,8 @@ from config import (
|
|||||||
ENABLE_MODEL_FILTER,
|
ENABLE_MODEL_FILTER,
|
||||||
MODEL_FILTER_LIST,
|
MODEL_FILTER_LIST,
|
||||||
UPLOAD_DIR,
|
UPLOAD_DIR,
|
||||||
|
config_set,
|
||||||
|
config_get,
|
||||||
)
|
)
|
||||||
from utils.misc import calculate_sha256
|
from utils.misc import calculate_sha256
|
||||||
|
|
||||||
@ -96,7 +98,7 @@ async def get_status():
|
|||||||
|
|
||||||
@app.get("/urls")
|
@app.get("/urls")
|
||||||
async def get_ollama_api_urls(user=Depends(get_admin_user)):
|
async def get_ollama_api_urls(user=Depends(get_admin_user)):
|
||||||
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
|
return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
|
||||||
|
|
||||||
|
|
||||||
class UrlUpdateForm(BaseModel):
|
class UrlUpdateForm(BaseModel):
|
||||||
@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.post("/urls/update")
|
@app.post("/urls/update")
|
||||||
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
|
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
|
||||||
app.state.OLLAMA_BASE_URLS = form_data.urls
|
config_set(app.state.OLLAMA_BASE_URLS, form_data.urls)
|
||||||
|
|
||||||
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
|
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
|
||||||
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
|
return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/cancel/{request_id}")
|
@app.get("/cancel/{request_id}")
|
||||||
@ -153,7 +155,9 @@ def merge_models_lists(model_lists):
|
|||||||
|
|
||||||
async def get_all_models():
|
async def get_all_models():
|
||||||
log.info("get_all_models()")
|
log.info("get_all_models()")
|
||||||
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
|
tasks = [
|
||||||
|
fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS)
|
||||||
|
]
|
||||||
responses = await asyncio.gather(*tasks)
|
responses = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
models = {
|
models = {
|
||||||
@ -179,14 +183,15 @@ async def get_ollama_tags(
|
|||||||
if user.role == "user":
|
if user.role == "user":
|
||||||
models["models"] = list(
|
models["models"] = list(
|
||||||
filter(
|
filter(
|
||||||
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
lambda model: model["name"]
|
||||||
|
in config_get(app.state.MODEL_FILTER_LIST),
|
||||||
models["models"],
|
models["models"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
return models
|
return models
|
||||||
else:
|
else:
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
try:
|
try:
|
||||||
r = requests.request(method="GET", url=f"{url}/api/tags")
|
r = requests.request(method="GET", url=f"{url}/api/tags")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
@ -216,7 +221,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
|
|
||||||
# returns lowest version
|
# returns lowest version
|
||||||
tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
|
tasks = [
|
||||||
|
fetch_url(f"{url}/api/version")
|
||||||
|
for url in config_get(app.state.OLLAMA_BASE_URLS)
|
||||||
|
]
|
||||||
responses = await asyncio.gather(*tasks)
|
responses = await asyncio.gather(*tasks)
|
||||||
responses = list(filter(lambda x: x is not None, responses))
|
responses = list(filter(lambda x: x is not None, responses))
|
||||||
|
|
||||||
@ -235,7 +243,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|||||||
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
try:
|
try:
|
||||||
r = requests.request(method="GET", url=f"{url}/api/version")
|
r = requests.request(method="GET", url=f"{url}/api/version")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
@ -267,7 +275,7 @@ class ModelNameForm(BaseModel):
|
|||||||
async def pull_model(
|
async def pull_model(
|
||||||
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -355,7 +363,7 @@ async def push_model(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.debug(f"url: {url}")
|
log.debug(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -417,7 +425,7 @@ async def create_model(
|
|||||||
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
log.debug(f"form_data: {form_data}")
|
log.debug(f"form_data: {form_data}")
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -490,7 +498,7 @@ async def copy_model(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -537,7 +545,7 @@ async def delete_model(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -577,7 +585,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
|
|||||||
)
|
)
|
||||||
|
|
||||||
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
|
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -634,7 +642,7 @@ async def generate_embeddings(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -684,7 +692,7 @@ def generate_ollama_embeddings(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -753,7 +761,7 @@ async def generate_completion(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -856,7 +864,7 @@ async def generate_chat_completion(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -965,7 +973,7 @@ async def generate_openai_chat_completion(
|
|||||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||||
)
|
)
|
||||||
|
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
log.info(f"url: {url}")
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -1064,7 +1072,7 @@ async def get_openai_models(
|
|||||||
}
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
try:
|
try:
|
||||||
r = requests.request(method="GET", url=f"{url}/api/tags")
|
r = requests.request(method="GET", url=f"{url}/api/tags")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
@ -1198,7 +1206,7 @@ async def download_model(
|
|||||||
|
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
url_idx = 0
|
url_idx = 0
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
|
|
||||||
file_name = parse_huggingface_url(form_data.url)
|
file_name = parse_huggingface_url(form_data.url)
|
||||||
|
|
||||||
@ -1217,7 +1225,7 @@ async def download_model(
|
|||||||
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
url_idx = 0
|
url_idx = 0
|
||||||
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
|
ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
|
|
||||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||||
|
|
||||||
@ -1282,7 +1290,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
|||||||
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
|
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
|
||||||
# if url_idx == None:
|
# if url_idx == None:
|
||||||
# url_idx = 0
|
# url_idx = 0
|
||||||
# url = app.state.OLLAMA_BASE_URLS[url_idx]
|
# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
|
||||||
|
|
||||||
# file_location = os.path.join(UPLOAD_DIR, file.filename)
|
# file_location = os.path.join(UPLOAD_DIR, file.filename)
|
||||||
# total_size = file.size
|
# total_size = file.size
|
||||||
@ -1319,7 +1327,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
|||||||
async def deprecated_proxy(
|
async def deprecated_proxy(
|
||||||
path: str, request: Request, user=Depends(get_verified_user)
|
path: str, request: Request, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
url = app.state.OLLAMA_BASE_URLS[0]
|
url = config_get(app.state.OLLAMA_BASE_URLS)[0]
|
||||||
target_url = f"{url}/{path}"
|
target_url = f"{url}/{path}"
|
||||||
|
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
@ -26,6 +26,8 @@ from config import (
|
|||||||
CACHE_DIR,
|
CACHE_DIR,
|
||||||
ENABLE_MODEL_FILTER,
|
ENABLE_MODEL_FILTER,
|
||||||
MODEL_FILTER_LIST,
|
MODEL_FILTER_LIST,
|
||||||
|
config_set,
|
||||||
|
config_get,
|
||||||
)
|
)
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
@ -75,32 +77,34 @@ class KeysUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.get("/urls")
|
@app.get("/urls")
|
||||||
async def get_openai_urls(user=Depends(get_admin_user)):
|
async def get_openai_urls(user=Depends(get_admin_user)):
|
||||||
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
|
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/urls/update")
|
@app.post("/urls/update")
|
||||||
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
|
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
|
||||||
await get_all_models()
|
await get_all_models()
|
||||||
app.state.OPENAI_API_BASE_URLS = form_data.urls
|
config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls)
|
||||||
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
|
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/keys")
|
@app.get("/keys")
|
||||||
async def get_openai_keys(user=Depends(get_admin_user)):
|
async def get_openai_keys(user=Depends(get_admin_user)):
|
||||||
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
|
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/keys/update")
|
@app.post("/keys/update")
|
||||||
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
|
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
|
||||||
app.state.OPENAI_API_KEYS = form_data.keys
|
config_set(app.state.OPENAI_API_KEYS, form_data.keys)
|
||||||
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
|
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/audio/speech")
|
@app.post("/audio/speech")
|
||||||
async def speech(request: Request, user=Depends(get_verified_user)):
|
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
idx = None
|
idx = None
|
||||||
try:
|
try:
|
||||||
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
|
idx = config_get(app.state.OPENAI_API_BASE_URLS).index(
|
||||||
|
"https://api.openai.com/v1"
|
||||||
|
)
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
name = hashlib.sha256(body).hexdigest()
|
name = hashlib.sha256(body).hexdigest()
|
||||||
|
|
||||||
@ -114,13 +118,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||||||
return FileResponse(file_path)
|
return FileResponse(file_path)
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
|
headers["Authorization"] = (
|
||||||
|
f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}"
|
||||||
|
)
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
|
url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech",
|
||||||
data=body,
|
data=body,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
stream=True,
|
stream=True,
|
||||||
@ -180,7 +186,8 @@ def merge_models_lists(model_lists):
|
|||||||
[
|
[
|
||||||
{**model, "urlIdx": idx}
|
{**model, "urlIdx": idx}
|
||||||
for model in models
|
for model in models
|
||||||
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
|
if "api.openai.com"
|
||||||
|
not in config_get(app.state.OPENAI_API_BASE_URLS)[idx]
|
||||||
or "gpt" in model["id"]
|
or "gpt" in model["id"]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -191,12 +198,15 @@ def merge_models_lists(model_lists):
|
|||||||
async def get_all_models():
|
async def get_all_models():
|
||||||
log.info("get_all_models()")
|
log.info("get_all_models()")
|
||||||
|
|
||||||
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
|
if (
|
||||||
|
len(config_get(app.state.OPENAI_API_KEYS)) == 1
|
||||||
|
and config_get(app.state.OPENAI_API_KEYS)[0] == ""
|
||||||
|
):
|
||||||
models = {"data": []}
|
models = {"data": []}
|
||||||
else:
|
else:
|
||||||
tasks = [
|
tasks = [
|
||||||
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
|
fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx])
|
||||||
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
|
for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS))
|
||||||
]
|
]
|
||||||
|
|
||||||
responses = await asyncio.gather(*tasks)
|
responses = await asyncio.gather(*tasks)
|
||||||
@ -228,18 +238,19 @@ async def get_all_models():
|
|||||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
models = await get_all_models()
|
models = await get_all_models()
|
||||||
if app.state.ENABLE_MODEL_FILTER:
|
if config_get(app.state.ENABLE_MODEL_FILTER):
|
||||||
if user.role == "user":
|
if user.role == "user":
|
||||||
models["data"] = list(
|
models["data"] = list(
|
||||||
filter(
|
filter(
|
||||||
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
lambda model: model["id"]
|
||||||
|
in config_get(app.state.MODEL_FILTER_LIST),
|
||||||
models["data"],
|
models["data"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
return models
|
return models
|
||||||
else:
|
else:
|
||||||
url = app.state.OPENAI_API_BASE_URLS[url_idx]
|
url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx]
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
|
|
||||||
@ -303,8 +314,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
log.error("Error loading request body into a dictionary:", e)
|
log.error("Error loading request body into a dictionary:", e)
|
||||||
|
|
||||||
url = app.state.OPENAI_API_BASE_URLS[idx]
|
url = config_get(app.state.OPENAI_API_BASE_URLS)[idx]
|
||||||
key = app.state.OPENAI_API_KEYS[idx]
|
key = config_get(app.state.OPENAI_API_KEYS)[idx]
|
||||||
|
|
||||||
target_url = f"{url}/{path}"
|
target_url = f"{url}/{path}"
|
||||||
|
|
||||||
|
@ -93,6 +93,8 @@ from config import (
|
|||||||
RAG_TEMPLATE,
|
RAG_TEMPLATE,
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||||
YOUTUBE_LOADER_LANGUAGE,
|
YOUTUBE_LOADER_LANGUAGE,
|
||||||
|
config_set,
|
||||||
|
config_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
@ -133,7 +135,7 @@ def update_embedding_model(
|
|||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
update_model: bool = False,
|
update_model: bool = False,
|
||||||
):
|
):
|
||||||
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
|
if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "":
|
||||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||||
get_model_path(embedding_model, update_model),
|
get_model_path(embedding_model, update_model),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
@ -158,22 +160,22 @@ def update_reranking_model(
|
|||||||
|
|
||||||
|
|
||||||
update_embedding_model(
|
update_embedding_model(
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
update_reranking_model(
|
update_reranking_model(
|
||||||
app.state.RAG_RERANKING_MODEL,
|
config_get(app.state.RAG_RERANKING_MODEL),
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.OPENAI_API_KEY,
|
config_get(app.state.OPENAI_API_KEY),
|
||||||
app.state.OPENAI_API_BASE_URL,
|
config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
)
|
)
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
@ -200,12 +202,12 @@ class UrlForm(CollectionNameForm):
|
|||||||
async def get_status():
|
async def get_status():
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"chunk_size": app.state.CHUNK_SIZE,
|
"chunk_size": config_get(app.state.CHUNK_SIZE),
|
||||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
|
||||||
"template": app.state.RAG_TEMPLATE,
|
"template": config_get(app.state.RAG_TEMPLATE),
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
"reranking_model": app.state.RAG_RERANKING_MODEL,
|
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -213,18 +215,21 @@ async def get_status():
|
|||||||
async def get_embedding_config(user=Depends(get_admin_user)):
|
async def get_embedding_config(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
"openai_config": {
|
"openai_config": {
|
||||||
"url": app.state.OPENAI_API_BASE_URL,
|
"url": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"key": app.state.OPENAI_API_KEY,
|
"key": config_get(app.state.OPENAI_API_KEY),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/reranking")
|
@app.get("/reranking")
|
||||||
async def get_reraanking_config(user=Depends(get_admin_user)):
|
async def get_reraanking_config(user=Depends(get_admin_user)):
|
||||||
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
|
return {
|
||||||
|
"status": True,
|
||||||
|
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConfigForm(BaseModel):
|
class OpenAIConfigForm(BaseModel):
|
||||||
@ -246,31 +251,31 @@ async def update_embedding_config(
|
|||||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine)
|
||||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model)
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]:
|
||||||
if form_data.openai_config != None:
|
if form_data.openai_config != None:
|
||||||
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url)
|
||||||
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key)
|
||||||
|
|
||||||
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True)
|
||||||
|
|
||||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.OPENAI_API_KEY,
|
config_get(app.state.OPENAI_API_KEY),
|
||||||
app.state.OPENAI_API_BASE_URL,
|
config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
"openai_config": {
|
"openai_config": {
|
||||||
"url": app.state.OPENAI_API_BASE_URL,
|
"url": config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
"key": app.state.OPENAI_API_KEY,
|
"key": config_get(app.state.OPENAI_API_KEY),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -293,13 +298,13 @@ async def update_reranking_config(
|
|||||||
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model)
|
||||||
|
|
||||||
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
|
update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"reranking_model": app.state.RAG_RERANKING_MODEL,
|
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Problem updating reranking model: {e}")
|
log.exception(f"Problem updating reranking model: {e}")
|
||||||
@ -313,14 +318,16 @@ async def update_reranking_config(
|
|||||||
async def get_rag_config(user=Depends(get_admin_user)):
|
async def get_rag_config(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
|
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
|
||||||
"chunk": {
|
"chunk": {
|
||||||
"chunk_size": app.state.CHUNK_SIZE,
|
"chunk_size": config_get(app.state.CHUNK_SIZE),
|
||||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
|
||||||
},
|
},
|
||||||
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
"web_loader_ssl_verification": config_get(
|
||||||
|
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||||
|
),
|
||||||
"youtube": {
|
"youtube": {
|
||||||
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
|
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
|
||||||
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -345,50 +352,69 @@ class ConfigUpdateForm(BaseModel):
|
|||||||
|
|
||||||
@app.post("/config/update")
|
@app.post("/config/update")
|
||||||
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
||||||
app.state.PDF_EXTRACT_IMAGES = (
|
config_set(
|
||||||
form_data.pdf_extract_images
|
app.state.PDF_EXTRACT_IMAGES,
|
||||||
if form_data.pdf_extract_images != None
|
(
|
||||||
else app.state.PDF_EXTRACT_IMAGES
|
form_data.pdf_extract_images
|
||||||
|
if form_data.pdf_extract_images is not None
|
||||||
|
else config_get(app.state.PDF_EXTRACT_IMAGES)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.CHUNK_SIZE = (
|
config_set(
|
||||||
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
|
app.state.CHUNK_SIZE,
|
||||||
|
(
|
||||||
|
form_data.chunk.chunk_size
|
||||||
|
if form_data.chunk is not None
|
||||||
|
else config_get(app.state.CHUNK_SIZE)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.CHUNK_OVERLAP = (
|
config_set(
|
||||||
form_data.chunk.chunk_overlap
|
app.state.CHUNK_OVERLAP,
|
||||||
if form_data.chunk != None
|
(
|
||||||
else app.state.CHUNK_OVERLAP
|
form_data.chunk.chunk_overlap
|
||||||
|
if form_data.chunk is not None
|
||||||
|
else config_get(app.state.CHUNK_OVERLAP)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
config_set(
|
||||||
form_data.web_loader_ssl_verification
|
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||||
if form_data.web_loader_ssl_verification != None
|
(
|
||||||
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
form_data.web_loader_ssl_verification
|
||||||
|
if form_data.web_loader_ssl_verification != None
|
||||||
|
else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.YOUTUBE_LOADER_LANGUAGE = (
|
config_set(
|
||||||
form_data.youtube.language
|
app.state.YOUTUBE_LOADER_LANGUAGE,
|
||||||
if form_data.youtube != None
|
(
|
||||||
else app.state.YOUTUBE_LOADER_LANGUAGE
|
form_data.youtube.language
|
||||||
|
if form_data.youtube is not None
|
||||||
|
else config_get(app.state.YOUTUBE_LOADER_LANGUAGE)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.YOUTUBE_LOADER_TRANSLATION = (
|
app.state.YOUTUBE_LOADER_TRANSLATION = (
|
||||||
form_data.youtube.translation
|
form_data.youtube.translation
|
||||||
if form_data.youtube != None
|
if form_data.youtube is not None
|
||||||
else app.state.YOUTUBE_LOADER_TRANSLATION
|
else app.state.YOUTUBE_LOADER_TRANSLATION
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
|
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
|
||||||
"chunk": {
|
"chunk": {
|
||||||
"chunk_size": app.state.CHUNK_SIZE,
|
"chunk_size": config_get(app.state.CHUNK_SIZE),
|
||||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
|
||||||
},
|
},
|
||||||
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
"web_loader_ssl_verification": config_get(
|
||||||
|
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||||
|
),
|
||||||
"youtube": {
|
"youtube": {
|
||||||
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
|
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
|
||||||
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -398,7 +424,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|||||||
async def get_rag_template(user=Depends(get_current_user)):
|
async def get_rag_template(user=Depends(get_current_user)):
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"template": app.state.RAG_TEMPLATE,
|
"template": config_get(app.state.RAG_TEMPLATE),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -406,10 +432,10 @@ async def get_rag_template(user=Depends(get_current_user)):
|
|||||||
async def get_query_settings(user=Depends(get_admin_user)):
|
async def get_query_settings(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"template": app.state.RAG_TEMPLATE,
|
"template": config_get(app.state.RAG_TEMPLATE),
|
||||||
"k": app.state.TOP_K,
|
"k": config_get(app.state.TOP_K),
|
||||||
"r": app.state.RELEVANCE_THRESHOLD,
|
"r": config_get(app.state.RELEVANCE_THRESHOLD),
|
||||||
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
|
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -424,16 +450,22 @@ class QuerySettingsForm(BaseModel):
|
|||||||
async def update_query_settings(
|
async def update_query_settings(
|
||||||
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
config_set(
|
||||||
app.state.TOP_K = form_data.k if form_data.k else 4
|
app.state.RAG_TEMPLATE,
|
||||||
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
form_data.template if form_data.template else RAG_TEMPLATE,
|
||||||
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
|
)
|
||||||
|
config_set(app.state.TOP_K, form_data.k if form_data.k else 4)
|
||||||
|
config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0)
|
||||||
|
config_set(
|
||||||
|
app.state.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
|
form_data.hybrid if form_data.hybrid else False,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"template": app.state.RAG_TEMPLATE,
|
"template": config_get(app.state.RAG_TEMPLATE),
|
||||||
"k": app.state.TOP_K,
|
"k": config_get(app.state.TOP_K),
|
||||||
"r": app.state.RELEVANCE_THRESHOLD,
|
"r": config_get(app.state.RELEVANCE_THRESHOLD),
|
||||||
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
|
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -451,21 +483,25 @@ def query_doc_handler(
|
|||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
|
||||||
return query_doc_with_hybrid_search(
|
return query_doc_with_hybrid_search(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
|
||||||
reranking_function=app.state.sentence_transformer_rf,
|
reranking_function=app.state.sentence_transformer_rf,
|
||||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
r=(
|
||||||
|
form_data.r
|
||||||
|
if form_data.r
|
||||||
|
else config_get(app.state.RELEVANCE_THRESHOLD)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return query_doc(
|
return query_doc(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
@ -489,21 +525,25 @@ def query_collection_handler(
|
|||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
|
||||||
return query_collection_with_hybrid_search(
|
return query_collection_with_hybrid_search(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
|
||||||
reranking_function=app.state.sentence_transformer_rf,
|
reranking_function=app.state.sentence_transformer_rf,
|
||||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
r=(
|
||||||
|
form_data.r
|
||||||
|
if form_data.r
|
||||||
|
else config_get(app.state.RELEVANCE_THRESHOLD)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return query_collection(
|
return query_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -520,8 +560,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
|
|||||||
loader = YoutubeLoader.from_youtube_url(
|
loader = YoutubeLoader.from_youtube_url(
|
||||||
form_data.url,
|
form_data.url,
|
||||||
add_video_info=True,
|
add_video_info=True,
|
||||||
language=app.state.YOUTUBE_LOADER_LANGUAGE,
|
language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
|
||||||
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
|
translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION),
|
||||||
)
|
)
|
||||||
data = loader.load()
|
data = loader.load()
|
||||||
|
|
||||||
@ -548,7 +588,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
|||||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||||
try:
|
try:
|
||||||
loader = get_web_loader(
|
loader = get_web_loader(
|
||||||
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
form_data.url,
|
||||||
|
verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION),
|
||||||
)
|
)
|
||||||
data = loader.load()
|
data = loader.load()
|
||||||
|
|
||||||
@ -604,8 +645,8 @@ def resolve_hostname(hostname):
|
|||||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=app.state.CHUNK_SIZE,
|
chunk_size=config_get(app.state.CHUNK_SIZE),
|
||||||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
|
||||||
add_start_index=True,
|
add_start_index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -622,8 +663,8 @@ def store_text_in_vector_db(
|
|||||||
text, metadata, collection_name, overwrite: bool = False
|
text, metadata, collection_name, overwrite: bool = False
|
||||||
) -> bool:
|
) -> bool:
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=app.state.CHUNK_SIZE,
|
chunk_size=config_get(app.state.CHUNK_SIZE),
|
||||||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
|
||||||
add_start_index=True,
|
add_start_index=True,
|
||||||
)
|
)
|
||||||
docs = text_splitter.create_documents([text], metadatas=[metadata])
|
docs = text_splitter.create_documents([text], metadatas=[metadata])
|
||||||
@ -646,11 +687,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||||
|
|
||||||
embedding_func = get_embedding_function(
|
embedding_func = get_embedding_function(
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
config_get(app.state.RAG_EMBEDDING_ENGINE),
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
config_get(app.state.RAG_EMBEDDING_MODEL),
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.OPENAI_API_KEY,
|
config_get(app.state.OPENAI_API_KEY),
|
||||||
app.state.OPENAI_API_BASE_URL,
|
config_get(app.state.OPENAI_API_BASE_URL),
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||||
@ -724,7 +765,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if file_ext == "pdf":
|
if file_ext == "pdf":
|
||||||
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
|
loader = PyPDFLoader(
|
||||||
|
file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES)
|
||||||
|
)
|
||||||
elif file_ext == "csv":
|
elif file_ext == "csv":
|
||||||
loader = CSVLoader(file_path)
|
loader = CSVLoader(file_path)
|
||||||
elif file_ext == "rst":
|
elif file_ext == "rst":
|
||||||
|
@ -21,6 +21,8 @@ from config import (
|
|||||||
USER_PERMISSIONS,
|
USER_PERMISSIONS,
|
||||||
WEBHOOK_URL,
|
WEBHOOK_URL,
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||||
|
JWT_EXPIRES_IN,
|
||||||
|
config_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -28,7 +30,7 @@ app = FastAPI()
|
|||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
|
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
|
||||||
app.state.JWT_EXPIRES_IN = "-1"
|
app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||||
|
|
||||||
app.state.DEFAULT_MODELS = DEFAULT_MODELS
|
app.state.DEFAULT_MODELS = DEFAULT_MODELS
|
||||||
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||||
@ -61,6 +63,6 @@ async def get_status():
|
|||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"auth": WEBUI_AUTH,
|
"auth": WEBUI_AUTH,
|
||||||
"default_models": app.state.DEFAULT_MODELS,
|
"default_models": config_get(app.state.DEFAULT_MODELS),
|
||||||
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
|
"default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS),
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ from utils.utils import (
|
|||||||
from utils.misc import parse_duration, validate_email_format
|
from utils.misc import parse_duration, validate_email_format
|
||||||
from utils.webhook import post_webhook
|
from utils.webhook import post_webhook
|
||||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||||
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
|
|||||||
if user:
|
if user:
|
||||||
token = create_token(
|
token = create_token(
|
||||||
data={"id": user.id},
|
data={"id": user.id},
|
||||||
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
|
expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
|
|||||||
|
|
||||||
@router.post("/signup", response_model=SigninResponse)
|
@router.post("/signup", response_model=SigninResponse)
|
||||||
async def signup(request: Request, form_data: SignupForm):
|
async def signup(request: Request, form_data: SignupForm):
|
||||||
if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH:
|
if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||||
)
|
)
|
||||||
@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
|
|||||||
role = (
|
role = (
|
||||||
"admin"
|
"admin"
|
||||||
if Users.get_num_users() == 0
|
if Users.get_num_users() == 0
|
||||||
else request.app.state.DEFAULT_USER_ROLE
|
else config_get(request.app.state.DEFAULT_USER_ROLE)
|
||||||
)
|
)
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
@ -194,13 +194,15 @@ async def signup(request: Request, form_data: SignupForm):
|
|||||||
if user:
|
if user:
|
||||||
token = create_token(
|
token = create_token(
|
||||||
data={"id": user.id},
|
data={"id": user.id},
|
||||||
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
|
expires_delta=parse_duration(
|
||||||
|
config_get(request.app.state.JWT_EXPIRES_IN)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# response.set_cookie(key='token', value=token, httponly=True)
|
# response.set_cookie(key='token', value=token, httponly=True)
|
||||||
|
|
||||||
if request.app.state.WEBHOOK_URL:
|
if config_get(request.app.state.WEBHOOK_URL):
|
||||||
post_webhook(
|
post_webhook(
|
||||||
request.app.state.WEBHOOK_URL,
|
config_get(request.app.state.WEBHOOK_URL),
|
||||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||||
{
|
{
|
||||||
"action": "signup",
|
"action": "signup",
|
||||||
@ -276,13 +278,15 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
|||||||
|
|
||||||
@router.get("/signup/enabled", response_model=bool)
|
@router.get("/signup/enabled", response_model=bool)
|
||||||
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
|
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
|
||||||
return request.app.state.ENABLE_SIGNUP
|
return config_get(request.app.state.ENABLE_SIGNUP)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/signup/enabled/toggle", response_model=bool)
|
@router.get("/signup/enabled/toggle", response_model=bool)
|
||||||
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
|
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
|
||||||
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
|
config_set(
|
||||||
return request.app.state.ENABLE_SIGNUP
|
request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP)
|
||||||
|
)
|
||||||
|
return config_get(request.app.state.ENABLE_SIGNUP)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
@ -292,7 +296,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
|
|||||||
|
|
||||||
@router.get("/signup/user/role")
|
@router.get("/signup/user/role")
|
||||||
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
|
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
|
||||||
return request.app.state.DEFAULT_USER_ROLE
|
return config_get(request.app.state.DEFAULT_USER_ROLE)
|
||||||
|
|
||||||
|
|
||||||
class UpdateRoleForm(BaseModel):
|
class UpdateRoleForm(BaseModel):
|
||||||
@ -304,8 +308,8 @@ async def update_default_user_role(
|
|||||||
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
|
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if form_data.role in ["pending", "user", "admin"]:
|
if form_data.role in ["pending", "user", "admin"]:
|
||||||
request.app.state.DEFAULT_USER_ROLE = form_data.role
|
config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role)
|
||||||
return request.app.state.DEFAULT_USER_ROLE
|
return config_get(request.app.state.DEFAULT_USER_ROLE)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
@ -315,7 +319,7 @@ async def update_default_user_role(
|
|||||||
|
|
||||||
@router.get("/token/expires")
|
@router.get("/token/expires")
|
||||||
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
|
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
|
||||||
return request.app.state.JWT_EXPIRES_IN
|
return config_get(request.app.state.JWT_EXPIRES_IN)
|
||||||
|
|
||||||
|
|
||||||
class UpdateJWTExpiresDurationForm(BaseModel):
|
class UpdateJWTExpiresDurationForm(BaseModel):
|
||||||
@ -332,10 +336,10 @@ async def update_token_expires_duration(
|
|||||||
|
|
||||||
# Check if the input string matches the pattern
|
# Check if the input string matches the pattern
|
||||||
if re.match(pattern, form_data.duration):
|
if re.match(pattern, form_data.duration):
|
||||||
request.app.state.JWT_EXPIRES_IN = form_data.duration
|
config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration)
|
||||||
return request.app.state.JWT_EXPIRES_IN
|
return config_get(request.app.state.JWT_EXPIRES_IN)
|
||||||
else:
|
else:
|
||||||
return request.app.state.JWT_EXPIRES_IN
|
return config_get(request.app.state.JWT_EXPIRES_IN)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
@ -9,6 +9,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
from config import config_set, config_get
|
||||||
|
|
||||||
from utils.utils import (
|
from utils.utils import (
|
||||||
get_password_hash,
|
get_password_hash,
|
||||||
@ -44,8 +45,8 @@ class SetDefaultSuggestionsForm(BaseModel):
|
|||||||
async def set_global_default_models(
|
async def set_global_default_models(
|
||||||
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
|
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
request.app.state.DEFAULT_MODELS = form_data.models
|
config_set(request.app.state.DEFAULT_MODELS, form_data.models)
|
||||||
return request.app.state.DEFAULT_MODELS
|
return config_get(request.app.state.DEFAULT_MODELS)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
|
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
|
||||||
@ -55,5 +56,5 @@ async def set_global_default_suggestions(
|
|||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
data = form_data.model_dump()
|
data = form_data.model_dump()
|
||||||
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
|
config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"])
|
||||||
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
|
return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS)
|
||||||
|
@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
|
|||||||
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS
|
from config import SRC_LOG_LEVELS, config_set, config_get
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
|
|||||||
|
|
||||||
@router.get("/permissions/user")
|
@router.get("/permissions/user")
|
||||||
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||||
return request.app.state.USER_PERMISSIONS
|
return config_get(request.app.state.USER_PERMISSIONS)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/permissions/user")
|
@router.post("/permissions/user")
|
||||||
async def update_user_permissions(
|
async def update_user_permissions(
|
||||||
request: Request, form_data: dict, user=Depends(get_admin_user)
|
request: Request, form_data: dict, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
request.app.state.USER_PERMISSIONS = form_data
|
config_set(request.app.state.USER_PERMISSIONS, form_data)
|
||||||
return request.app.state.USER_PERMISSIONS
|
return config_get(request.app.state.USER_PERMISSIONS)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
@ -5,6 +5,7 @@ import chromadb
|
|||||||
from chromadb import Settings
|
from chromadb import Settings
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from typing import TypeVar, Generic, Union
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
@ -17,7 +18,6 @@ import shutil
|
|||||||
from secrets import token_bytes
|
from secrets import token_bytes
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Load .env file
|
# Load .env file
|
||||||
####################################
|
####################################
|
||||||
@ -29,7 +29,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
print("dotenv not installed, skipping...")
|
print("dotenv not installed, skipping...")
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# LOGGING
|
# LOGGING
|
||||||
####################################
|
####################################
|
||||||
@ -71,7 +70,6 @@ for source in log_sources:
|
|||||||
|
|
||||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||||
|
|
||||||
|
|
||||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||||
if WEBUI_NAME != "Open WebUI":
|
if WEBUI_NAME != "Open WebUI":
|
||||||
WEBUI_NAME += " (Open WebUI)"
|
WEBUI_NAME += " (Open WebUI)"
|
||||||
@ -80,7 +78,6 @@ WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
|
|||||||
|
|
||||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# ENV (dev,test,prod)
|
# ENV (dev,test,prod)
|
||||||
####################################
|
####################################
|
||||||
@ -151,26 +148,14 @@ for version in soup.find_all("h2"):
|
|||||||
|
|
||||||
changelog_json[version_number] = version_data
|
changelog_json[version_number] = version_data
|
||||||
|
|
||||||
|
|
||||||
CHANGELOG = changelog_json
|
CHANGELOG = changelog_json
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI_VERSION
|
# WEBUI_VERSION
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
|
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
|
||||||
|
|
||||||
####################################
|
|
||||||
# WEBUI_AUTH (Required for security)
|
|
||||||
####################################
|
|
||||||
|
|
||||||
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
|
||||||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# DATA/FRONTEND BUILD DIR
|
# DATA/FRONTEND BUILD DIR
|
||||||
####################################
|
####################################
|
||||||
@ -184,6 +169,93 @@ try:
|
|||||||
except:
|
except:
|
||||||
CONFIG_DATA = {}
|
CONFIG_DATA = {}
|
||||||
|
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# Config helpers
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
def save_config():
|
||||||
|
try:
|
||||||
|
with open(f"{DATA_DIR}/config.json", "w") as f:
|
||||||
|
json.dump(CONFIG_DATA, f, indent="\t")
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_value(config_path: str):
|
||||||
|
path_parts = config_path.split(".")
|
||||||
|
cur_config = CONFIG_DATA
|
||||||
|
for key in path_parts:
|
||||||
|
if key in cur_config:
|
||||||
|
cur_config = cur_config[key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return cur_config
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedConfig(Generic[T]):
|
||||||
|
def __init__(self, env_name: str, config_path: str, env_value: T):
|
||||||
|
self.env_name = env_name
|
||||||
|
self.config_path = config_path
|
||||||
|
self.env_value = env_value
|
||||||
|
self.config_value = get_config_value(config_path)
|
||||||
|
if self.config_value is not None:
|
||||||
|
log.info(f"'{env_name}' loaded from config.json")
|
||||||
|
self.value = self.config_value
|
||||||
|
else:
|
||||||
|
self.value = env_value
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
# Don't save if the value is the same as the env value and the config value
|
||||||
|
if self.env_value == self.value:
|
||||||
|
if self.config_value == self.value:
|
||||||
|
return
|
||||||
|
log.info(f"Saving '{self.env_name}' to config.json")
|
||||||
|
path_parts = self.config_path.split(".")
|
||||||
|
config = CONFIG_DATA
|
||||||
|
for key in path_parts[:-1]:
|
||||||
|
if key not in config:
|
||||||
|
config[key] = {}
|
||||||
|
config = config[key]
|
||||||
|
config[path_parts[-1]] = self.value
|
||||||
|
save_config()
|
||||||
|
self.config_value = self.value
|
||||||
|
|
||||||
|
|
||||||
|
def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True):
|
||||||
|
if isinstance(config, WrappedConfig):
|
||||||
|
config.value = value
|
||||||
|
if save_config:
|
||||||
|
config.save()
|
||||||
|
else:
|
||||||
|
config = value
|
||||||
|
|
||||||
|
|
||||||
|
def config_get(config: Union[WrappedConfig[T], T]) -> T:
|
||||||
|
if isinstance(config, WrappedConfig):
|
||||||
|
return config.value
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# WEBUI_AUTH (Required for security)
|
||||||
|
####################################
|
||||||
|
|
||||||
|
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
||||||
|
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
||||||
|
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||||
|
)
|
||||||
|
JWT_EXPIRES_IN = WrappedConfig(
|
||||||
|
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Static DIR
|
# Static DIR
|
||||||
####################################
|
####################################
|
||||||
@ -225,7 +297,6 @@ if CUSTOM_NAME:
|
|||||||
log.exception(e)
|
log.exception(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# File Upload DIR
|
# File Upload DIR
|
||||||
####################################
|
####################################
|
||||||
@ -233,7 +304,6 @@ if CUSTOM_NAME:
|
|||||||
UPLOAD_DIR = f"{DATA_DIR}/uploads"
|
UPLOAD_DIR = f"{DATA_DIR}/uploads"
|
||||||
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Cache DIR
|
# Cache DIR
|
||||||
####################################
|
####################################
|
||||||
@ -241,7 +311,6 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
|||||||
CACHE_DIR = f"{DATA_DIR}/cache"
|
CACHE_DIR = f"{DATA_DIR}/cache"
|
||||||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Docs DIR
|
# Docs DIR
|
||||||
####################################
|
####################################
|
||||||
@ -282,7 +351,6 @@ if not os.path.exists(LITELLM_CONFIG_PATH):
|
|||||||
create_config_file(LITELLM_CONFIG_PATH)
|
create_config_file(LITELLM_CONFIG_PATH)
|
||||||
log.info("Config file created successfully.")
|
log.info("Config file created successfully.")
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# OLLAMA_BASE_URL
|
# OLLAMA_BASE_URL
|
||||||
####################################
|
####################################
|
||||||
@ -313,12 +381,13 @@ if ENV == "prod":
|
|||||||
elif K8S_FLAG:
|
elif K8S_FLAG:
|
||||||
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
|
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
|
||||||
|
|
||||||
|
|
||||||
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
|
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
|
||||||
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
|
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
|
||||||
|
|
||||||
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
|
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
|
||||||
|
OLLAMA_BASE_URLS = WrappedConfig(
|
||||||
|
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# OPENAI_API
|
# OPENAI_API
|
||||||
@ -327,7 +396,6 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
|
|||||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||||
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
||||||
|
|
||||||
|
|
||||||
if OPENAI_API_BASE_URL == "":
|
if OPENAI_API_BASE_URL == "":
|
||||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
|
||||||
@ -335,7 +403,7 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
|
|||||||
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
|
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
|
||||||
|
|
||||||
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
|
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
|
||||||
|
OPENAI_API_KEYS = WrappedConfig("OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS)
|
||||||
|
|
||||||
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
|
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
|
||||||
OPENAI_API_BASE_URLS = (
|
OPENAI_API_BASE_URLS = (
|
||||||
@ -346,37 +414,42 @@ OPENAI_API_BASE_URLS = [
|
|||||||
url.strip() if url != "" else "https://api.openai.com/v1"
|
url.strip() if url != "" else "https://api.openai.com/v1"
|
||||||
for url in OPENAI_API_BASE_URLS.split(";")
|
for url in OPENAI_API_BASE_URLS.split(";")
|
||||||
]
|
]
|
||||||
|
OPENAI_API_BASE_URLS = WrappedConfig(
|
||||||
|
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
|
||||||
|
)
|
||||||
|
|
||||||
OPENAI_API_KEY = ""
|
OPENAI_API_KEY = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OPENAI_API_KEY = OPENAI_API_KEYS[
|
OPENAI_API_KEY = OPENAI_API_KEYS.value[
|
||||||
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
|
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
|
||||||
]
|
]
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI
|
# WEBUI
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
ENABLE_SIGNUP = (
|
ENABLE_SIGNUP = WrappedConfig(
|
||||||
False
|
"ENABLE_SIGNUP",
|
||||||
if WEBUI_AUTH == False
|
"ui.enable_signup",
|
||||||
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
|
(
|
||||||
|
False
|
||||||
|
if not WEBUI_AUTH
|
||||||
|
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
DEFAULT_MODELS = WrappedConfig(
|
||||||
|
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
|
||||||
)
|
)
|
||||||
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
|
|
||||||
|
|
||||||
|
DEFAULT_PROMPT_SUGGESTIONS = WrappedConfig(
|
||||||
DEFAULT_PROMPT_SUGGESTIONS = (
|
"DEFAULT_PROMPT_SUGGESTIONS",
|
||||||
CONFIG_DATA["ui"]["prompt_suggestions"]
|
"ui.prompt_suggestions",
|
||||||
if "ui" in CONFIG_DATA
|
[
|
||||||
and "prompt_suggestions" in CONFIG_DATA["ui"]
|
|
||||||
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
|
|
||||||
else [
|
|
||||||
{
|
{
|
||||||
"title": ["Help me study", "vocabulary for a college entrance exam"],
|
"title": ["Help me study", "vocabulary for a college entrance exam"],
|
||||||
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
|
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
|
||||||
@ -404,23 +477,42 @@ DEFAULT_PROMPT_SUGGESTIONS = (
|
|||||||
"title": ["Overcome procrastination", "give me tips"],
|
"title": ["Overcome procrastination", "give me tips"],
|
||||||
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
|
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
|
||||||
},
|
},
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEFAULT_USER_ROLE = WrappedConfig(
|
||||||
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
|
"DEFAULT_USER_ROLE",
|
||||||
|
"ui.default_user_role",
|
||||||
USER_PERMISSIONS_CHAT_DELETION = (
|
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
||||||
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
|
USER_PERMISSIONS_CHAT_DELETION = WrappedConfig(
|
||||||
|
"USER_PERMISSIONS_CHAT_DELETION",
|
||||||
|
"ui.user_permissions.chat.deletion",
|
||||||
|
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
|
USER_PERMISSIONS = WrappedConfig(
|
||||||
|
"USER_PERMISSIONS",
|
||||||
|
"ui.user_permissions",
|
||||||
|
{"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
|
||||||
|
)
|
||||||
|
|
||||||
|
ENABLE_MODEL_FILTER = WrappedConfig(
|
||||||
|
"ENABLE_MODEL_FILTER",
|
||||||
|
"model_filter.enable",
|
||||||
|
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
|
||||||
|
)
|
||||||
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
||||||
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
|
MODEL_FILTER_LIST = WrappedConfig(
|
||||||
|
"MODEL_FILTER_LIST",
|
||||||
|
"model_filter.list",
|
||||||
|
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
|
||||||
|
)
|
||||||
|
|
||||||
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
|
WEBHOOK_URL = WrappedConfig(
|
||||||
|
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
|
||||||
|
)
|
||||||
|
|
||||||
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
|
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
|
||||||
|
|
||||||
@ -458,26 +550,45 @@ else:
|
|||||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||||
|
|
||||||
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
|
RAG_TOP_K = WrappedConfig(
|
||||||
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
|
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
|
||||||
|
)
|
||||||
ENABLE_RAG_HYBRID_SEARCH = (
|
RAG_RELEVANCE_THRESHOLD = WrappedConfig(
|
||||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
|
"RAG_RELEVANCE_THRESHOLD",
|
||||||
|
"rag.relevance_threshold",
|
||||||
|
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ENABLE_RAG_HYBRID_SEARCH = WrappedConfig(
|
||||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
"ENABLE_RAG_HYBRID_SEARCH",
|
||||||
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
|
"rag.enable_hybrid_search",
|
||||||
|
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
|
||||||
)
|
)
|
||||||
|
|
||||||
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
|
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = WrappedConfig(
|
||||||
|
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
|
||||||
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
|
"rag.enable_web_loader_ssl_verification",
|
||||||
|
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
|
||||||
RAG_EMBEDDING_MODEL = os.environ.get(
|
|
||||||
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
|
|
||||||
)
|
)
|
||||||
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
|
|
||||||
|
RAG_EMBEDDING_ENGINE = WrappedConfig(
|
||||||
|
"RAG_EMBEDDING_ENGINE",
|
||||||
|
"rag.embedding_engine",
|
||||||
|
os.environ.get("RAG_EMBEDDING_ENGINE", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
PDF_EXTRACT_IMAGES = WrappedConfig(
|
||||||
|
"PDF_EXTRACT_IMAGES",
|
||||||
|
"rag.pdf_extract_images",
|
||||||
|
os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
RAG_EMBEDDING_MODEL = WrappedConfig(
|
||||||
|
"RAG_EMBEDDING_MODEL",
|
||||||
|
"rag.embedding_model",
|
||||||
|
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
|
||||||
|
)
|
||||||
|
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
|
||||||
|
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
||||||
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
@ -487,9 +598,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
|||||||
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
|
RAG_RERANKING_MODEL = WrappedConfig(
|
||||||
if not RAG_RERANKING_MODEL == "":
|
"RAG_RERANKING_MODEL",
|
||||||
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
|
"rag.reranking_model",
|
||||||
|
os.environ.get("RAG_RERANKING_MODEL", ""),
|
||||||
|
)
|
||||||
|
if RAG_RERANKING_MODEL.value != "":
|
||||||
|
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
|
||||||
|
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||||
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
@ -499,7 +614,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
|||||||
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if CHROMA_HTTP_HOST != "":
|
if CHROMA_HTTP_HOST != "":
|
||||||
CHROMA_CLIENT = chromadb.HttpClient(
|
CHROMA_CLIENT = chromadb.HttpClient(
|
||||||
host=CHROMA_HTTP_HOST,
|
host=CHROMA_HTTP_HOST,
|
||||||
@ -518,7 +632,6 @@ else:
|
|||||||
database=CHROMA_DATABASE,
|
database=CHROMA_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||||
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
||||||
|
|
||||||
@ -527,9 +640,14 @@ if USE_CUDA.lower() == "true":
|
|||||||
else:
|
else:
|
||||||
DEVICE_TYPE = "cpu"
|
DEVICE_TYPE = "cpu"
|
||||||
|
|
||||||
|
CHUNK_SIZE = WrappedConfig(
|
||||||
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
|
"CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
|
||||||
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
|
)
|
||||||
|
CHUNK_OVERLAP = WrappedConfig(
|
||||||
|
"CHUNK_OVERLAP",
|
||||||
|
"rag.chunk_overlap",
|
||||||
|
int(os.environ.get("CHUNK_OVERLAP", "100")),
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||||
<context>
|
<context>
|
||||||
@ -545,16 +663,32 @@ And answer according to the language of the user's question.
|
|||||||
Given the context information, answer the query.
|
Given the context information, answer the query.
|
||||||
Query: [query]"""
|
Query: [query]"""
|
||||||
|
|
||||||
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
|
RAG_TEMPLATE = WrappedConfig(
|
||||||
|
"RAG_TEMPLATE",
|
||||||
|
"rag.template",
|
||||||
|
os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
|
||||||
|
)
|
||||||
|
|
||||||
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
RAG_OPENAI_API_BASE_URL = WrappedConfig(
|
||||||
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
|
"RAG_OPENAI_API_BASE_URL",
|
||||||
|
"rag.openai_api_base_url",
|
||||||
|
os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
|
||||||
|
)
|
||||||
|
RAG_OPENAI_API_KEY = WrappedConfig(
|
||||||
|
"RAG_OPENAI_API_KEY",
|
||||||
|
"rag.openai_api_key",
|
||||||
|
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||||
|
)
|
||||||
|
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH = (
|
ENABLE_RAG_LOCAL_WEB_FETCH = (
|
||||||
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
|
YOUTUBE_LOADER_LANGUAGE = WrappedConfig(
|
||||||
|
"YOUTUBE_LOADER_LANGUAGE",
|
||||||
|
"rag.youtube_loader_language",
|
||||||
|
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Transcribe
|
# Transcribe
|
||||||
@ -566,39 +700,82 @@ WHISPER_MODEL_AUTO_UPDATE = (
|
|||||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Images
|
# Images
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
|
IMAGE_GENERATION_ENGINE = WrappedConfig(
|
||||||
|
"IMAGE_GENERATION_ENGINE",
|
||||||
ENABLE_IMAGE_GENERATION = (
|
"image_generation.engine",
|
||||||
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
|
os.getenv("IMAGE_GENERATION_ENGINE", ""),
|
||||||
)
|
)
|
||||||
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
|
|
||||||
|
|
||||||
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
|
ENABLE_IMAGE_GENERATION = WrappedConfig(
|
||||||
|
"ENABLE_IMAGE_GENERATION",
|
||||||
IMAGES_OPENAI_API_BASE_URL = os.getenv(
|
"image_generation.enable",
|
||||||
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
|
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
|
||||||
|
)
|
||||||
|
AUTOMATIC1111_BASE_URL = WrappedConfig(
|
||||||
|
"AUTOMATIC1111_BASE_URL",
|
||||||
|
"image_generation.automatic1111.base_url",
|
||||||
|
os.getenv("AUTOMATIC1111_BASE_URL", ""),
|
||||||
)
|
)
|
||||||
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
|
|
||||||
|
|
||||||
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
|
COMFYUI_BASE_URL = WrappedConfig(
|
||||||
|
"COMFYUI_BASE_URL",
|
||||||
|
"image_generation.comfyui.base_url",
|
||||||
|
os.getenv("COMFYUI_BASE_URL", ""),
|
||||||
|
)
|
||||||
|
|
||||||
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
|
IMAGES_OPENAI_API_BASE_URL = WrappedConfig(
|
||||||
|
"IMAGES_OPENAI_API_BASE_URL",
|
||||||
|
"image_generation.openai.api_base_url",
|
||||||
|
os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
|
||||||
|
)
|
||||||
|
IMAGES_OPENAI_API_KEY = WrappedConfig(
|
||||||
|
"IMAGES_OPENAI_API_KEY",
|
||||||
|
"image_generation.openai.api_key",
|
||||||
|
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||||
|
)
|
||||||
|
|
||||||
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
|
IMAGE_SIZE = WrappedConfig(
|
||||||
|
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
|
||||||
|
)
|
||||||
|
|
||||||
|
IMAGE_STEPS = WrappedConfig(
|
||||||
|
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
|
||||||
|
)
|
||||||
|
|
||||||
|
IMAGE_GENERATION_MODEL = WrappedConfig(
|
||||||
|
"IMAGE_GENERATION_MODEL",
|
||||||
|
"image_generation.model",
|
||||||
|
os.getenv("IMAGE_GENERATION_MODEL", ""),
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Audio
|
# Audio
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
AUDIO_OPENAI_API_BASE_URL = WrappedConfig(
|
||||||
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
|
"AUDIO_OPENAI_API_BASE_URL",
|
||||||
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
|
"audio.openai.api_base_url",
|
||||||
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
|
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
|
||||||
|
)
|
||||||
|
AUDIO_OPENAI_API_KEY = WrappedConfig(
|
||||||
|
"AUDIO_OPENAI_API_KEY",
|
||||||
|
"audio.openai.api_key",
|
||||||
|
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||||
|
)
|
||||||
|
AUDIO_OPENAI_API_MODEL = WrappedConfig(
|
||||||
|
"AUDIO_OPENAI_API_MODEL",
|
||||||
|
"audio.openai.api_model",
|
||||||
|
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
|
||||||
|
)
|
||||||
|
AUDIO_OPENAI_API_VOICE = WrappedConfig(
|
||||||
|
"AUDIO_OPENAI_API_VOICE",
|
||||||
|
"audio.openai.api_voice",
|
||||||
|
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# LiteLLM
|
# LiteLLM
|
||||||
@ -612,7 +789,6 @@ if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
|
|||||||
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
|
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
|
||||||
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
|
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Database
|
# Database
|
||||||
####################################
|
####################################
|
||||||
|
@ -58,6 +58,8 @@ from config import (
|
|||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
WEBHOOK_URL,
|
WEBHOOK_URL,
|
||||||
ENABLE_ADMIN_EXPORT,
|
ENABLE_ADMIN_EXPORT,
|
||||||
|
config_get,
|
||||||
|
config_set,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
@ -243,9 +245,11 @@ async def get_app_config():
|
|||||||
"version": VERSION,
|
"version": VERSION,
|
||||||
"auth": WEBUI_AUTH,
|
"auth": WEBUI_AUTH,
|
||||||
"default_locale": default_locale,
|
"default_locale": default_locale,
|
||||||
"images": images_app.state.ENABLED,
|
"images": config_get(images_app.state.ENABLED),
|
||||||
"default_models": webui_app.state.DEFAULT_MODELS,
|
"default_models": config_get(webui_app.state.DEFAULT_MODELS),
|
||||||
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
|
"default_prompt_suggestions": config_get(
|
||||||
|
webui_app.state.DEFAULT_PROMPT_SUGGESTIONS
|
||||||
|
),
|
||||||
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||||
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
|
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
|
||||||
}
|
}
|
||||||
@ -254,8 +258,8 @@ async def get_app_config():
|
|||||||
@app.get("/api/config/model/filter")
|
@app.get("/api/config/model/filter")
|
||||||
async def get_model_filter_config(user=Depends(get_admin_user)):
|
async def get_model_filter_config(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"enabled": app.state.ENABLE_MODEL_FILTER,
|
"enabled": config_get(app.state.ENABLE_MODEL_FILTER),
|
||||||
"models": app.state.MODEL_FILTER_LIST,
|
"models": config_get(app.state.MODEL_FILTER_LIST),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -268,28 +272,28 @@ class ModelFilterConfigForm(BaseModel):
|
|||||||
async def update_model_filter_config(
|
async def update_model_filter_config(
|
||||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
app.state.ENABLE_MODEL_FILTER = form_data.enabled
|
config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled)
|
||||||
app.state.MODEL_FILTER_LIST = form_data.models
|
config_set(app.state.MODEL_FILTER_LIST, form_data.models)
|
||||||
|
|
||||||
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
|
||||||
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
|
||||||
|
|
||||||
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
|
||||||
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
|
||||||
|
|
||||||
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
|
||||||
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"enabled": app.state.ENABLE_MODEL_FILTER,
|
"enabled": config_get(app.state.ENABLE_MODEL_FILTER),
|
||||||
"models": app.state.MODEL_FILTER_LIST,
|
"models": config_get(app.state.MODEL_FILTER_LIST),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/webhook")
|
@app.get("/api/webhook")
|
||||||
async def get_webhook_url(user=Depends(get_admin_user)):
|
async def get_webhook_url(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"url": app.state.WEBHOOK_URL,
|
"url": config_get(app.state.WEBHOOK_URL),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -299,12 +303,12 @@ class UrlForm(BaseModel):
|
|||||||
|
|
||||||
@app.post("/api/webhook")
|
@app.post("/api/webhook")
|
||||||
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
||||||
app.state.WEBHOOK_URL = form_data.url
|
config_set(app.state.WEBHOOK_URL, form_data.url)
|
||||||
|
|
||||||
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
|
webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"url": app.state.WEBHOOK_URL,
|
"url": config_get(app.state.WEBHOOK_URL),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user