mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: switch to config proxy, remove config_get/set
This commit is contained in:
		
							parent
							
								
									f712c90019
								
							
						
					
					
						commit
						298e6848b3
					
				@ -45,8 +45,7 @@ from config import (
 | 
			
		||||
    AUDIO_OPENAI_API_KEY,
 | 
			
		||||
    AUDIO_OPENAI_API_MODEL,
 | 
			
		||||
    AUDIO_OPENAI_API_VOICE,
 | 
			
		||||
    config_get,
 | 
			
		||||
    config_set,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
@ -61,11 +60,11 @@ app.add_middleware(
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
 | 
			
		||||
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
 | 
			
		||||
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
 | 
			
		||||
app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
 | 
			
		||||
app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
 | 
			
		||||
 | 
			
		||||
# setting device type for whisper model
 | 
			
		||||
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
 | 
			
		||||
@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
 | 
			
		||||
@app.get("/config")
 | 
			
		||||
async def get_openai_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
        "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
        "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
 | 
			
		||||
        "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
 | 
			
		||||
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
 | 
			
		||||
        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -99,22 +98,17 @@ async def update_openai_config(
 | 
			
		||||
    if form_data.key == "":
 | 
			
		||||
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 | 
			
		||||
 | 
			
		||||
    config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
 | 
			
		||||
    config_set(app.state.OPENAI_API_KEY, form_data.key)
 | 
			
		||||
    config_set(app.state.OPENAI_API_MODEL, form_data.model)
 | 
			
		||||
    config_set(app.state.OPENAI_API_VOICE, form_data.speaker)
 | 
			
		||||
 | 
			
		||||
    app.state.OPENAI_API_BASE_URL.save()
 | 
			
		||||
    app.state.OPENAI_API_KEY.save()
 | 
			
		||||
    app.state.OPENAI_API_MODEL.save()
 | 
			
		||||
    app.state.OPENAI_API_VOICE.save()
 | 
			
		||||
    app.state.config.OPENAI_API_BASE_URL = form_data.url
 | 
			
		||||
    app.state.config.OPENAI_API_KEY = form_data.key
 | 
			
		||||
    app.state.config.OPENAI_API_MODEL = form_data.model
 | 
			
		||||
    app.state.config.OPENAI_API_VOICE = form_data.speaker
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
        "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
        "OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL),
 | 
			
		||||
        "OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE),
 | 
			
		||||
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
 | 
			
		||||
        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
        return FileResponse(file_path)
 | 
			
		||||
 | 
			
		||||
    headers = {}
 | 
			
		||||
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
 | 
			
		||||
    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
 | 
			
		||||
    headers["Content-Type"] = "application/json"
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
    try:
 | 
			
		||||
        r = requests.post(
 | 
			
		||||
            url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
 | 
			
		||||
            url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
 | 
			
		||||
            data=body,
 | 
			
		||||
            headers=headers,
 | 
			
		||||
            stream=True,
 | 
			
		||||
 | 
			
		||||
@ -42,8 +42,7 @@ from config import (
 | 
			
		||||
    IMAGE_GENERATION_MODEL,
 | 
			
		||||
    IMAGE_SIZE,
 | 
			
		||||
    IMAGE_STEPS,
 | 
			
		||||
    config_get,
 | 
			
		||||
    config_set,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -62,28 +61,30 @@ app.add_middleware(
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.ENGINE = IMAGE_GENERATION_ENGINE
 | 
			
		||||
app.state.ENABLED = ENABLE_IMAGE_GENERATION
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
 | 
			
		||||
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
 | 
			
		||||
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
 | 
			
		||||
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
 | 
			
		||||
 | 
			
		||||
app.state.MODEL = IMAGE_GENERATION_MODEL
 | 
			
		||||
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
 | 
			
		||||
 | 
			
		||||
app.state.config.MODEL = IMAGE_GENERATION_MODEL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 | 
			
		||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 | 
			
		||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 | 
			
		||||
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.IMAGE_SIZE = IMAGE_SIZE
 | 
			
		||||
app.state.IMAGE_STEPS = IMAGE_STEPS
 | 
			
		||||
app.state.config.IMAGE_SIZE = IMAGE_SIZE
 | 
			
		||||
app.state.config.IMAGE_STEPS = IMAGE_STEPS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/config")
 | 
			
		||||
async def get_config(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "engine": config_get(app.state.ENGINE),
 | 
			
		||||
        "enabled": config_get(app.state.ENABLED),
 | 
			
		||||
        "engine": app.state.config.ENGINE,
 | 
			
		||||
        "enabled": app.state.config.ENABLED,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.post("/config/update")
 | 
			
		||||
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(app.state.ENGINE, form_data.engine)
 | 
			
		||||
    config_set(app.state.ENABLED, form_data.enabled)
 | 
			
		||||
    app.state.config.ENGINE = form_data.engine
 | 
			
		||||
    app.state.config.ENABLED = form_data.enabled
 | 
			
		||||
    return {
 | 
			
		||||
        "engine": config_get(app.state.ENGINE),
 | 
			
		||||
        "enabled": config_get(app.state.ENABLED),
 | 
			
		||||
        "engine": app.state.config.ENGINE,
 | 
			
		||||
        "enabled": app.state.config.ENABLED,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
 | 
			
		||||
@app.get("/url")
 | 
			
		||||
async def get_engine_url(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
 | 
			
		||||
        "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
 | 
			
		||||
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
 | 
			
		||||
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -121,29 +122,29 @@ async def update_engine_url(
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    if form_data.AUTOMATIC1111_BASE_URL == None:
 | 
			
		||||
        config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
 | 
			
		||||
        app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 | 
			
		||||
    else:
 | 
			
		||||
        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.head(url)
 | 
			
		||||
            config_set(app.state.AUTOMATIC1111_BASE_URL, url)
 | 
			
		||||
            app.state.config.AUTOMATIC1111_BASE_URL = url
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
 | 
			
		||||
    if form_data.COMFYUI_BASE_URL == None:
 | 
			
		||||
        config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
 | 
			
		||||
        app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 | 
			
		||||
    else:
 | 
			
		||||
        url = form_data.COMFYUI_BASE_URL.strip("/")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.head(url)
 | 
			
		||||
            config_set(app.state.COMFYUI_BASE_URL, url)
 | 
			
		||||
            app.state.config.COMFYUI_BASE_URL = url
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
 | 
			
		||||
        "COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
 | 
			
		||||
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
 | 
			
		||||
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
 | 
			
		||||
        "status": True,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
 | 
			
		||||
@app.get("/openai/config")
 | 
			
		||||
async def get_openai_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
        "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -168,13 +169,13 @@ async def update_openai_config(
 | 
			
		||||
    if form_data.key == "":
 | 
			
		||||
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 | 
			
		||||
 | 
			
		||||
    config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
 | 
			
		||||
    config_set(app.state.OPENAI_API_KEY, form_data.key)
 | 
			
		||||
    app.state.config.OPENAI_API_BASE_URL = form_data.url
 | 
			
		||||
    app.state.config.OPENAI_API_KEY = form_data.key
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
        "OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.get("/size")
 | 
			
		||||
async def get_image_size(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
 | 
			
		||||
    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/size/update")
 | 
			
		||||
@ -193,9 +194,9 @@ async def update_image_size(
 | 
			
		||||
):
 | 
			
		||||
    pattern = r"^\d+x\d+$"  # Regular expression pattern
 | 
			
		||||
    if re.match(pattern, form_data.size):
 | 
			
		||||
        config_set(app.state.IMAGE_SIZE, form_data.size)
 | 
			
		||||
        app.state.config.IMAGE_SIZE = form_data.size
 | 
			
		||||
        return {
 | 
			
		||||
            "IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
 | 
			
		||||
            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
 | 
			
		||||
            "status": True,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.get("/steps")
 | 
			
		||||
async def get_image_size(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
 | 
			
		||||
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/steps/update")
 | 
			
		||||
@ -219,9 +220,9 @@ async def update_image_size(
 | 
			
		||||
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    if form_data.steps >= 0:
 | 
			
		||||
        config_set(app.state.IMAGE_STEPS, form_data.steps)
 | 
			
		||||
        app.state.config.IMAGE_STEPS = form_data.steps
 | 
			
		||||
        return {
 | 
			
		||||
            "IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
 | 
			
		||||
            "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
 | 
			
		||||
            "status": True,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
@ -234,14 +235,14 @@ async def update_image_size(
 | 
			
		||||
@app.get("/models")
 | 
			
		||||
def get_models(user=Depends(get_current_user)):
 | 
			
		||||
    try:
 | 
			
		||||
        if app.state.ENGINE == "openai":
 | 
			
		||||
        if app.state.config.ENGINE == "openai":
 | 
			
		||||
            return [
 | 
			
		||||
                {"id": "dall-e-2", "name": "DALL·E 2"},
 | 
			
		||||
                {"id": "dall-e-3", "name": "DALL·E 3"},
 | 
			
		||||
            ]
 | 
			
		||||
        elif app.state.ENGINE == "comfyui":
 | 
			
		||||
        elif app.state.config.ENGINE == "comfyui":
 | 
			
		||||
 | 
			
		||||
            r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
 | 
			
		||||
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
 | 
			
		||||
            info = r.json()
 | 
			
		||||
 | 
			
		||||
            return list(
 | 
			
		||||
@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            r = requests.get(
 | 
			
		||||
                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
 | 
			
		||||
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
 | 
			
		||||
            )
 | 
			
		||||
            models = r.json()
 | 
			
		||||
            return list(
 | 
			
		||||
@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        app.state.ENABLED = False
 | 
			
		||||
        app.state.config.ENABLED = False
 | 
			
		||||
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/models/default")
 | 
			
		||||
async def get_default_model(user=Depends(get_admin_user)):
 | 
			
		||||
    try:
 | 
			
		||||
        if app.state.ENGINE == "openai":
 | 
			
		||||
        if app.state.config.ENGINE == "openai":
 | 
			
		||||
            return {
 | 
			
		||||
                "model": (
 | 
			
		||||
                    config_get(app.state.MODEL)
 | 
			
		||||
                    if config_get(app.state.MODEL)
 | 
			
		||||
                    else "dall-e-2"
 | 
			
		||||
                )
 | 
			
		||||
            }
 | 
			
		||||
        elif app.state.ENGINE == "comfyui":
 | 
			
		||||
            return {
 | 
			
		||||
                "model": (
 | 
			
		||||
                    config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
 | 
			
		||||
                    app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
 | 
			
		||||
                )
 | 
			
		||||
            }
 | 
			
		||||
        elif app.state.config.ENGINE == "comfyui":
 | 
			
		||||
            return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
 | 
			
		||||
        else:
 | 
			
		||||
            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
 | 
			
		||||
            r = requests.get(
 | 
			
		||||
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
 | 
			
		||||
            )
 | 
			
		||||
            options = r.json()
 | 
			
		||||
            return {"model": options["sd_model_checkpoint"]}
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        config_set(app.state.ENABLED, False)
 | 
			
		||||
        app.state.config.ENABLED = False
 | 
			
		||||
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_model_handler(model: str):
 | 
			
		||||
    if app.state.ENGINE in ["openai", "comfyui"]:
 | 
			
		||||
        config_set(app.state.MODEL, model)
 | 
			
		||||
        return config_get(app.state.MODEL)
 | 
			
		||||
    if app.state.config.ENGINE in ["openai", "comfyui"]:
 | 
			
		||||
        app.state.config.MODEL = model
 | 
			
		||||
        return app.state.config.MODEL
 | 
			
		||||
    else:
 | 
			
		||||
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
 | 
			
		||||
        r = requests.get(
 | 
			
		||||
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
 | 
			
		||||
        )
 | 
			
		||||
        options = r.json()
 | 
			
		||||
 | 
			
		||||
        if model != options["sd_model_checkpoint"]:
 | 
			
		||||
            options["sd_model_checkpoint"] = model
 | 
			
		||||
            r = requests.post(
 | 
			
		||||
                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
 | 
			
		||||
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
 | 
			
		||||
                json=options,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return options
 | 
			
		||||
@ -397,30 +397,32 @@ def generate_image(
 | 
			
		||||
    user=Depends(get_current_user),
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
 | 
			
		||||
    width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
    try:
 | 
			
		||||
        if app.state.ENGINE == "openai":
 | 
			
		||||
        if app.state.config.ENGINE == "openai":
 | 
			
		||||
 | 
			
		||||
            headers = {}
 | 
			
		||||
            headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
 | 
			
		||||
            headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
 | 
			
		||||
            headers["Content-Type"] = "application/json"
 | 
			
		||||
 | 
			
		||||
            data = {
 | 
			
		||||
                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
 | 
			
		||||
                "model": (
 | 
			
		||||
                    app.state.config.MODEL
 | 
			
		||||
                    if app.state.config.MODEL != ""
 | 
			
		||||
                    else "dall-e-2"
 | 
			
		||||
                ),
 | 
			
		||||
                "prompt": form_data.prompt,
 | 
			
		||||
                "n": form_data.n,
 | 
			
		||||
                "size": (
 | 
			
		||||
                    form_data.size
 | 
			
		||||
                    if form_data.size
 | 
			
		||||
                    else config_get(app.state.IMAGE_SIZE)
 | 
			
		||||
                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
 | 
			
		||||
                ),
 | 
			
		||||
                "response_format": "b64_json",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            r = requests.post(
 | 
			
		||||
                url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
 | 
			
		||||
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
 | 
			
		||||
                json=data,
 | 
			
		||||
                headers=headers,
 | 
			
		||||
            )
 | 
			
		||||
@ -440,7 +442,7 @@ def generate_image(
 | 
			
		||||
 | 
			
		||||
            return images
 | 
			
		||||
 | 
			
		||||
        elif app.state.ENGINE == "comfyui":
 | 
			
		||||
        elif app.state.config.ENGINE == "comfyui":
 | 
			
		||||
 | 
			
		||||
            data = {
 | 
			
		||||
                "prompt": form_data.prompt,
 | 
			
		||||
@ -449,8 +451,8 @@ def generate_image(
 | 
			
		||||
                "n": form_data.n,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if config_get(app.state.IMAGE_STEPS) is not None:
 | 
			
		||||
                data["steps"] = config_get(app.state.IMAGE_STEPS)
 | 
			
		||||
            if app.state.config.IMAGE_STEPS is not None:
 | 
			
		||||
                data["steps"] = app.state.config.IMAGE_STEPS
 | 
			
		||||
 | 
			
		||||
            if form_data.negative_prompt is not None:
 | 
			
		||||
                data["negative_prompt"] = form_data.negative_prompt
 | 
			
		||||
@ -458,10 +460,10 @@ def generate_image(
 | 
			
		||||
            data = ImageGenerationPayload(**data)
 | 
			
		||||
 | 
			
		||||
            res = comfyui_generate_image(
 | 
			
		||||
                config_get(app.state.MODEL),
 | 
			
		||||
                app.state.config.MODEL,
 | 
			
		||||
                data,
 | 
			
		||||
                user.id,
 | 
			
		||||
                config_get(app.state.COMFYUI_BASE_URL),
 | 
			
		||||
                app.state.config.COMFYUI_BASE_URL,
 | 
			
		||||
            )
 | 
			
		||||
            log.debug(f"res: {res}")
 | 
			
		||||
 | 
			
		||||
@ -488,14 +490,14 @@ def generate_image(
 | 
			
		||||
                "height": height,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if config_get(app.state.IMAGE_STEPS) is not None:
 | 
			
		||||
                data["steps"] = config_get(app.state.IMAGE_STEPS)
 | 
			
		||||
            if app.state.config.IMAGE_STEPS is not None:
 | 
			
		||||
                data["steps"] = app.state.config.IMAGE_STEPS
 | 
			
		||||
 | 
			
		||||
            if form_data.negative_prompt is not None:
 | 
			
		||||
                data["negative_prompt"] = form_data.negative_prompt
 | 
			
		||||
 | 
			
		||||
            r = requests.post(
 | 
			
		||||
                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
 | 
			
		||||
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
 | 
			
		||||
                json=data,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -46,8 +46,7 @@ from config import (
 | 
			
		||||
    ENABLE_MODEL_FILTER,
 | 
			
		||||
    MODEL_FILTER_LIST,
 | 
			
		||||
    UPLOAD_DIR,
 | 
			
		||||
    config_set,
 | 
			
		||||
    config_get,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from utils.misc import calculate_sha256
 | 
			
		||||
 | 
			
		||||
@ -63,11 +62,12 @@ app.add_middleware(
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
 | 
			
		||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 | 
			
		||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 | 
			
		||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 | 
			
		||||
app.state.MODELS = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -98,7 +98,7 @@ async def get_status():
 | 
			
		||||
 | 
			
		||||
@app.get("/urls")
 | 
			
		||||
async def get_ollama_api_urls(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
 | 
			
		||||
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UrlUpdateForm(BaseModel):
 | 
			
		||||
@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.post("/urls/update")
 | 
			
		||||
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(app.state.OLLAMA_BASE_URLS, form_data.urls)
 | 
			
		||||
    app.state.config.OLLAMA_BASE_URLS = form_data.urls
 | 
			
		||||
 | 
			
		||||
    log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
 | 
			
		||||
    return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)}
 | 
			
		||||
    log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
 | 
			
		||||
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/cancel/{request_id}")
 | 
			
		||||
@ -155,9 +155,7 @@ def merge_models_lists(model_lists):
 | 
			
		||||
 | 
			
		||||
async def get_all_models():
 | 
			
		||||
    log.info("get_all_models()")
 | 
			
		||||
    tasks = [
 | 
			
		||||
        fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS)
 | 
			
		||||
    ]
 | 
			
		||||
    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
 | 
			
		||||
    responses = await asyncio.gather(*tasks)
 | 
			
		||||
 | 
			
		||||
    models = {
 | 
			
		||||
@ -183,15 +181,14 @@ async def get_ollama_tags(
 | 
			
		||||
            if user.role == "user":
 | 
			
		||||
                models["models"] = list(
 | 
			
		||||
                    filter(
 | 
			
		||||
                        lambda model: model["name"]
 | 
			
		||||
                        in config_get(app.state.MODEL_FILTER_LIST),
 | 
			
		||||
                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
 | 
			
		||||
                        models["models"],
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                return models
 | 
			
		||||
        return models
 | 
			
		||||
    else:
 | 
			
		||||
        url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.request(method="GET", url=f"{url}/api/tags")
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
 | 
			
		||||
 | 
			
		||||
        # returns lowest version
 | 
			
		||||
        tasks = [
 | 
			
		||||
            fetch_url(f"{url}/api/version")
 | 
			
		||||
            for url in config_get(app.state.OLLAMA_BASE_URLS)
 | 
			
		||||
            fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
 | 
			
		||||
        ]
 | 
			
		||||
        responses = await asyncio.gather(*tasks)
 | 
			
		||||
        responses = list(filter(lambda x: x is not None, responses))
 | 
			
		||||
@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
 | 
			
		||||
                detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.request(method="GET", url=f"{url}/api/version")
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
@ -275,7 +271,7 @@ class ModelNameForm(BaseModel):
 | 
			
		||||
async def pull_model(
 | 
			
		||||
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -363,7 +359,7 @@ async def push_model(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.debug(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -425,7 +421,7 @@ async def create_model(
 | 
			
		||||
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    log.debug(f"form_data: {form_data}")
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -498,7 +494,7 @@ async def copy_model(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -545,7 +541,7 @@ async def delete_model(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -642,7 +638,7 @@ async def generate_embeddings(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -692,7 +688,7 @@ def generate_ollama_embeddings(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
@ -761,7 +757,7 @@ async def generate_completion(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -864,7 +860,7 @@ async def generate_chat_completion(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -973,7 +969,7 @@ async def generate_openai_chat_completion(
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
@ -1072,7 +1068,7 @@ async def get_openai_models(
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.request(method="GET", url=f"{url}/api/tags")
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
@ -1206,7 +1202,7 @@ async def download_model(
 | 
			
		||||
 | 
			
		||||
    if url_idx == None:
 | 
			
		||||
        url_idx = 0
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
 | 
			
		||||
    file_name = parse_huggingface_url(form_data.url)
 | 
			
		||||
 | 
			
		||||
@ -1225,7 +1221,7 @@ async def download_model(
 | 
			
		||||
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
 | 
			
		||||
    if url_idx == None:
 | 
			
		||||
        url_idx = 0
 | 
			
		||||
    ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
 | 
			
		||||
    file_path = f"{UPLOAD_DIR}/{file.filename}"
 | 
			
		||||
 | 
			
		||||
@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
 | 
			
		||||
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
 | 
			
		||||
#     if url_idx == None:
 | 
			
		||||
#         url_idx = 0
 | 
			
		||||
#     url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx]
 | 
			
		||||
#     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
 | 
			
		||||
#     file_location = os.path.join(UPLOAD_DIR, file.filename)
 | 
			
		||||
#     total_size = file.size
 | 
			
		||||
@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
 | 
			
		||||
async def deprecated_proxy(
 | 
			
		||||
    path: str, request: Request, user=Depends(get_verified_user)
 | 
			
		||||
):
 | 
			
		||||
    url = config_get(app.state.OLLAMA_BASE_URLS)[0]
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[0]
 | 
			
		||||
    target_url = f"{url}/{path}"
 | 
			
		||||
 | 
			
		||||
    body = await request.body()
 | 
			
		||||
 | 
			
		||||
@ -26,8 +26,7 @@ from config import (
 | 
			
		||||
    CACHE_DIR,
 | 
			
		||||
    ENABLE_MODEL_FILTER,
 | 
			
		||||
    MODEL_FILTER_LIST,
 | 
			
		||||
    config_set,
 | 
			
		||||
    config_get,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
 | 
			
		||||
@ -47,11 +46,13 @@ app.add_middleware(
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
 | 
			
		||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 | 
			
		||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 | 
			
		||||
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 | 
			
		||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 | 
			
		||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
 | 
			
		||||
 | 
			
		||||
app.state.MODELS = {}
 | 
			
		||||
 | 
			
		||||
@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.get("/urls")
 | 
			
		||||
async def get_openai_urls(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
 | 
			
		||||
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/urls/update")
 | 
			
		||||
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
 | 
			
		||||
    await get_all_models()
 | 
			
		||||
    config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls)
 | 
			
		||||
    return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)}
 | 
			
		||||
    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
 | 
			
		||||
    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/keys")
 | 
			
		||||
async def get_openai_keys(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
 | 
			
		||||
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/keys/update")
 | 
			
		||||
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(app.state.OPENAI_API_KEYS, form_data.keys)
 | 
			
		||||
    return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)}
 | 
			
		||||
    app.state.config.OPENAI_API_KEYS = form_data.keys
 | 
			
		||||
    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/audio/speech")
 | 
			
		||||
async def speech(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
    idx = None
 | 
			
		||||
    try:
 | 
			
		||||
        idx = config_get(app.state.OPENAI_API_BASE_URLS).index(
 | 
			
		||||
            "https://api.openai.com/v1"
 | 
			
		||||
        )
 | 
			
		||||
        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
 | 
			
		||||
        body = await request.body()
 | 
			
		||||
        name = hashlib.sha256(body).hexdigest()
 | 
			
		||||
 | 
			
		||||
@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
            return FileResponse(file_path)
 | 
			
		||||
 | 
			
		||||
        headers = {}
 | 
			
		||||
        headers["Authorization"] = (
 | 
			
		||||
            f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}"
 | 
			
		||||
        )
 | 
			
		||||
        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
 | 
			
		||||
        headers["Content-Type"] = "application/json"
 | 
			
		||||
 | 
			
		||||
        r = None
 | 
			
		||||
        try:
 | 
			
		||||
            r = requests.post(
 | 
			
		||||
                url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech",
 | 
			
		||||
                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
 | 
			
		||||
                data=body,
 | 
			
		||||
                headers=headers,
 | 
			
		||||
                stream=True,
 | 
			
		||||
@ -187,7 +184,7 @@ def merge_models_lists(model_lists):
 | 
			
		||||
                    {**model, "urlIdx": idx}
 | 
			
		||||
                    for model in models
 | 
			
		||||
                    if "api.openai.com"
 | 
			
		||||
                    not in config_get(app.state.OPENAI_API_BASE_URLS)[idx]
 | 
			
		||||
                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
 | 
			
		||||
                    or "gpt" in model["id"]
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
@ -199,14 +196,14 @@ async def get_all_models():
 | 
			
		||||
    log.info("get_all_models()")
 | 
			
		||||
 | 
			
		||||
    if (
 | 
			
		||||
        len(config_get(app.state.OPENAI_API_KEYS)) == 1
 | 
			
		||||
        and config_get(app.state.OPENAI_API_KEYS)[0] == ""
 | 
			
		||||
        len(app.state.config.OPENAI_API_KEYS) == 1
 | 
			
		||||
        and app.state.config.OPENAI_API_KEYS[0] == ""
 | 
			
		||||
    ):
 | 
			
		||||
        models = {"data": []}
 | 
			
		||||
    else:
 | 
			
		||||
        tasks = [
 | 
			
		||||
            fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx])
 | 
			
		||||
            for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS))
 | 
			
		||||
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
 | 
			
		||||
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        responses = await asyncio.gather(*tasks)
 | 
			
		||||
@ -238,19 +235,18 @@ async def get_all_models():
 | 
			
		||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
 | 
			
		||||
    if url_idx == None:
 | 
			
		||||
        models = await get_all_models()
 | 
			
		||||
        if config_get(app.state.ENABLE_MODEL_FILTER):
 | 
			
		||||
        if app.state.ENABLE_MODEL_FILTER:
 | 
			
		||||
            if user.role == "user":
 | 
			
		||||
                models["data"] = list(
 | 
			
		||||
                    filter(
 | 
			
		||||
                        lambda model: model["id"]
 | 
			
		||||
                        in config_get(app.state.MODEL_FILTER_LIST),
 | 
			
		||||
                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
 | 
			
		||||
                        models["data"],
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                return models
 | 
			
		||||
        return models
 | 
			
		||||
    else:
 | 
			
		||||
        url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx]
 | 
			
		||||
        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
 | 
			
		||||
 | 
			
		||||
        r = None
 | 
			
		||||
 | 
			
		||||
@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
    except json.JSONDecodeError as e:
 | 
			
		||||
        log.error("Error loading request body into a dictionary:", e)
 | 
			
		||||
 | 
			
		||||
    url = config_get(app.state.OPENAI_API_BASE_URLS)[idx]
 | 
			
		||||
    key = config_get(app.state.OPENAI_API_KEYS)[idx]
 | 
			
		||||
    url = app.state.config.OPENAI_API_BASE_URLS[idx]
 | 
			
		||||
    key = app.state.config.OPENAI_API_KEYS[idx]
 | 
			
		||||
 | 
			
		||||
    target_url = f"{url}/{path}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -93,8 +93,7 @@ from config import (
 | 
			
		||||
    RAG_TEMPLATE,
 | 
			
		||||
    ENABLE_RAG_LOCAL_WEB_FETCH,
 | 
			
		||||
    YOUTUBE_LOADER_LANGUAGE,
 | 
			
		||||
    config_set,
 | 
			
		||||
    config_get,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
@ -104,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
app.state.TOP_K = RAG_TOP_K
 | 
			
		||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
 | 
			
		||||
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
 | 
			
		||||
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
 | 
			
		||||
app.state.config.TOP_K = RAG_TOP_K
 | 
			
		||||
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 | 
			
		||||
 | 
			
		||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
 | 
			
		||||
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
 | 
			
		||||
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.CHUNK_SIZE = CHUNK_SIZE
 | 
			
		||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 | 
			
		||||
app.state.config.CHUNK_SIZE = CHUNK_SIZE
 | 
			
		||||
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 | 
			
		||||
 | 
			
		||||
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 | 
			
		||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 | 
			
		||||
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 | 
			
		||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
 | 
			
		||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 | 
			
		||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 | 
			
		||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 | 
			
		||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 | 
			
		||||
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 | 
			
		||||
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 | 
			
		||||
 | 
			
		||||
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
 | 
			
		||||
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
 | 
			
		||||
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
 | 
			
		||||
app.state.YOUTUBE_LOADER_TRANSLATION = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -135,7 +136,7 @@ def update_embedding_model(
 | 
			
		||||
    embedding_model: str,
 | 
			
		||||
    update_model: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "":
 | 
			
		||||
    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
 | 
			
		||||
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
 | 
			
		||||
            get_model_path(embedding_model, update_model),
 | 
			
		||||
            device=DEVICE_TYPE,
 | 
			
		||||
@ -160,22 +161,22 @@ def update_reranking_model(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
update_embedding_model(
 | 
			
		||||
    config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
    app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
update_reranking_model(
 | 
			
		||||
    config_get(app.state.RAG_RERANKING_MODEL),
 | 
			
		||||
    app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
    RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
 | 
			
		||||
    config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
    config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
    app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
    app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
    app.state.sentence_transformer_ef,
 | 
			
		||||
    config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
    config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
    app.state.config.OPENAI_API_KEY,
 | 
			
		||||
    app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
origins = ["*"]
 | 
			
		||||
@ -202,12 +203,12 @@ class UrlForm(CollectionNameForm):
 | 
			
		||||
async def get_status():
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "chunk_size": config_get(app.state.CHUNK_SIZE),
 | 
			
		||||
        "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
 | 
			
		||||
        "template": config_get(app.state.RAG_TEMPLATE),
 | 
			
		||||
        "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
        "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
        "reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
 | 
			
		||||
        "chunk_size": app.state.config.CHUNK_SIZE,
 | 
			
		||||
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
 | 
			
		||||
        "template": app.state.config.RAG_TEMPLATE,
 | 
			
		||||
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -215,11 +216,11 @@ async def get_status():
 | 
			
		||||
async def get_embedding_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
        "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
        "openai_config": {
 | 
			
		||||
            "url": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
            "key": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
            "url": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
            "key": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -228,7 +229,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
 | 
			
		||||
async def get_reraanking_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
 | 
			
		||||
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -248,34 +249,34 @@ async def update_embedding_config(
 | 
			
		||||
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    log.info(
 | 
			
		||||
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
 | 
			
		||||
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
        config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine)
 | 
			
		||||
        config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model)
 | 
			
		||||
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
 | 
			
		||||
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
 | 
			
		||||
 | 
			
		||||
        if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]:
 | 
			
		||||
        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
 | 
			
		||||
            if form_data.openai_config != None:
 | 
			
		||||
                config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url)
 | 
			
		||||
                config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key)
 | 
			
		||||
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
 | 
			
		||||
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key
 | 
			
		||||
 | 
			
		||||
        update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True)
 | 
			
		||||
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
 | 
			
		||||
 | 
			
		||||
        app.state.EMBEDDING_FUNCTION = get_embedding_function(
 | 
			
		||||
            config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
            config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
            app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
            app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
            app.state.sentence_transformer_ef,
 | 
			
		||||
            config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
            config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
            app.state.config.OPENAI_API_KEY,
 | 
			
		||||
            app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "status": True,
 | 
			
		||||
            "embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
            "embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
            "openai_config": {
 | 
			
		||||
                "url": config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
                "key": config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
                "url": app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
                "key": app.state.config.OPENAI_API_KEY,
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
@ -295,16 +296,16 @@ async def update_reranking_config(
 | 
			
		||||
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    log.info(
 | 
			
		||||
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
 | 
			
		||||
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
        config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model)
 | 
			
		||||
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
 | 
			
		||||
 | 
			
		||||
        update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True)
 | 
			
		||||
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "status": True,
 | 
			
		||||
            "reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
 | 
			
		||||
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
        }
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(f"Problem updating reranking model: {e}")
 | 
			
		||||
@ -318,16 +319,14 @@ async def update_reranking_config(
 | 
			
		||||
async def get_rag_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
 | 
			
		||||
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
 | 
			
		||||
        "chunk": {
 | 
			
		||||
            "chunk_size": config_get(app.state.CHUNK_SIZE),
 | 
			
		||||
            "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
 | 
			
		||||
            "chunk_size": app.state.config.CHUNK_SIZE,
 | 
			
		||||
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
 | 
			
		||||
        },
 | 
			
		||||
        "web_loader_ssl_verification": config_get(
 | 
			
		||||
            app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 | 
			
		||||
        ),
 | 
			
		||||
        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
 | 
			
		||||
        "youtube": {
 | 
			
		||||
            "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
 | 
			
		||||
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
 | 
			
		||||
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
@ -352,49 +351,34 @@ class ConfigUpdateForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.post("/config/update")
 | 
			
		||||
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.PDF_EXTRACT_IMAGES,
 | 
			
		||||
        (
 | 
			
		||||
            form_data.pdf_extract_images
 | 
			
		||||
            if form_data.pdf_extract_images is not None
 | 
			
		||||
            else config_get(app.state.PDF_EXTRACT_IMAGES)
 | 
			
		||||
        ),
 | 
			
		||||
    app.state.config.PDF_EXTRACT_IMAGES = (
 | 
			
		||||
        form_data.pdf_extract_images
 | 
			
		||||
        if form_data.pdf_extract_images is not None
 | 
			
		||||
        else app.state.config.PDF_EXTRACT_IMAGES
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.CHUNK_SIZE,
 | 
			
		||||
        (
 | 
			
		||||
            form_data.chunk.chunk_size
 | 
			
		||||
            if form_data.chunk is not None
 | 
			
		||||
            else config_get(app.state.CHUNK_SIZE)
 | 
			
		||||
        ),
 | 
			
		||||
    app.state.config.CHUNK_SIZE = (
 | 
			
		||||
        form_data.chunk.chunk_size
 | 
			
		||||
        if form_data.chunk is not None
 | 
			
		||||
        else app.state.config.CHUNK_SIZE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.CHUNK_OVERLAP,
 | 
			
		||||
        (
 | 
			
		||||
            form_data.chunk.chunk_overlap
 | 
			
		||||
            if form_data.chunk is not None
 | 
			
		||||
            else config_get(app.state.CHUNK_OVERLAP)
 | 
			
		||||
        ),
 | 
			
		||||
    app.state.config.CHUNK_OVERLAP = (
 | 
			
		||||
        form_data.chunk.chunk_overlap
 | 
			
		||||
        if form_data.chunk is not None
 | 
			
		||||
        else app.state.config.CHUNK_OVERLAP
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
 | 
			
		||||
        (
 | 
			
		||||
            form_data.web_loader_ssl_verification
 | 
			
		||||
            if form_data.web_loader_ssl_verification != None
 | 
			
		||||
            else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION)
 | 
			
		||||
        ),
 | 
			
		||||
    app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
 | 
			
		||||
        form_data.web_loader_ssl_verification
 | 
			
		||||
        if form_data.web_loader_ssl_verification != None
 | 
			
		||||
        else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.YOUTUBE_LOADER_LANGUAGE,
 | 
			
		||||
        (
 | 
			
		||||
            form_data.youtube.language
 | 
			
		||||
            if form_data.youtube is not None
 | 
			
		||||
            else config_get(app.state.YOUTUBE_LOADER_LANGUAGE)
 | 
			
		||||
        ),
 | 
			
		||||
    app.state.config.YOUTUBE_LOADER_LANGUAGE = (
 | 
			
		||||
        form_data.youtube.language
 | 
			
		||||
        if form_data.youtube is not None
 | 
			
		||||
        else app.state.config.YOUTUBE_LOADER_LANGUAGE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    app.state.YOUTUBE_LOADER_TRANSLATION = (
 | 
			
		||||
@ -405,16 +389,14 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
 | 
			
		||||
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
 | 
			
		||||
        "chunk": {
 | 
			
		||||
            "chunk_size": config_get(app.state.CHUNK_SIZE),
 | 
			
		||||
            "chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
 | 
			
		||||
            "chunk_size": app.state.config.CHUNK_SIZE,
 | 
			
		||||
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
 | 
			
		||||
        },
 | 
			
		||||
        "web_loader_ssl_verification": config_get(
 | 
			
		||||
            app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 | 
			
		||||
        ),
 | 
			
		||||
        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
 | 
			
		||||
        "youtube": {
 | 
			
		||||
            "language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
 | 
			
		||||
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
 | 
			
		||||
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
@ -424,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
 | 
			
		||||
async def get_rag_template(user=Depends(get_current_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "template": config_get(app.state.RAG_TEMPLATE),
 | 
			
		||||
        "template": app.state.config.RAG_TEMPLATE,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -432,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
 | 
			
		||||
async def get_query_settings(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "template": config_get(app.state.RAG_TEMPLATE),
 | 
			
		||||
        "k": config_get(app.state.TOP_K),
 | 
			
		||||
        "r": config_get(app.state.RELEVANCE_THRESHOLD),
 | 
			
		||||
        "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
 | 
			
		||||
        "template": app.state.config.RAG_TEMPLATE,
 | 
			
		||||
        "k": app.state.config.TOP_K,
 | 
			
		||||
        "r": app.state.config.RELEVANCE_THRESHOLD,
 | 
			
		||||
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -450,22 +432,20 @@ class QuerySettingsForm(BaseModel):
 | 
			
		||||
async def update_query_settings(
 | 
			
		||||
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.RAG_TEMPLATE,
 | 
			
		||||
    app.state.config.RAG_TEMPLATE = (
 | 
			
		||||
        form_data.template if form_data.template else RAG_TEMPLATE,
 | 
			
		||||
    )
 | 
			
		||||
    config_set(app.state.TOP_K, form_data.k if form_data.k else 4)
 | 
			
		||||
    config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0)
 | 
			
		||||
    config_set(
 | 
			
		||||
        app.state.ENABLE_RAG_HYBRID_SEARCH,
 | 
			
		||||
    app.state.config.TOP_K = form_data.k if form_data.k else 4
 | 
			
		||||
    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
 | 
			
		||||
    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
 | 
			
		||||
        form_data.hybrid if form_data.hybrid else False,
 | 
			
		||||
    )
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "template": config_get(app.state.RAG_TEMPLATE),
 | 
			
		||||
        "k": config_get(app.state.TOP_K),
 | 
			
		||||
        "r": config_get(app.state.RELEVANCE_THRESHOLD),
 | 
			
		||||
        "hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
 | 
			
		||||
        "template": app.state.config.RAG_TEMPLATE,
 | 
			
		||||
        "k": app.state.config.TOP_K,
 | 
			
		||||
        "r": app.state.config.RELEVANCE_THRESHOLD,
 | 
			
		||||
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -483,17 +463,15 @@ def query_doc_handler(
 | 
			
		||||
    user=Depends(get_current_user),
 | 
			
		||||
):
 | 
			
		||||
    try:
 | 
			
		||||
        if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
 | 
			
		||||
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
 | 
			
		||||
            return query_doc_with_hybrid_search(
 | 
			
		||||
                collection_name=form_data.collection_name,
 | 
			
		||||
                query=form_data.query,
 | 
			
		||||
                embedding_function=app.state.EMBEDDING_FUNCTION,
 | 
			
		||||
                k=form_data.k if form_data.k else config_get(app.state.TOP_K),
 | 
			
		||||
                k=form_data.k if form_data.k else app.state.config.TOP_K,
 | 
			
		||||
                reranking_function=app.state.sentence_transformer_rf,
 | 
			
		||||
                r=(
 | 
			
		||||
                    form_data.r
 | 
			
		||||
                    if form_data.r
 | 
			
		||||
                    else config_get(app.state.RELEVANCE_THRESHOLD)
 | 
			
		||||
                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
@ -501,7 +479,7 @@ def query_doc_handler(
 | 
			
		||||
                collection_name=form_data.collection_name,
 | 
			
		||||
                query=form_data.query,
 | 
			
		||||
                embedding_function=app.state.EMBEDDING_FUNCTION,
 | 
			
		||||
                k=form_data.k if form_data.k else config_get(app.state.TOP_K),
 | 
			
		||||
                k=form_data.k if form_data.k else app.state.config.TOP_K,
 | 
			
		||||
            )
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(e)
 | 
			
		||||
@ -525,17 +503,15 @@ def query_collection_handler(
 | 
			
		||||
    user=Depends(get_current_user),
 | 
			
		||||
):
 | 
			
		||||
    try:
 | 
			
		||||
        if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
 | 
			
		||||
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
 | 
			
		||||
            return query_collection_with_hybrid_search(
 | 
			
		||||
                collection_names=form_data.collection_names,
 | 
			
		||||
                query=form_data.query,
 | 
			
		||||
                embedding_function=app.state.EMBEDDING_FUNCTION,
 | 
			
		||||
                k=form_data.k if form_data.k else config_get(app.state.TOP_K),
 | 
			
		||||
                k=form_data.k if form_data.k else app.state.config.TOP_K,
 | 
			
		||||
                reranking_function=app.state.sentence_transformer_rf,
 | 
			
		||||
                r=(
 | 
			
		||||
                    form_data.r
 | 
			
		||||
                    if form_data.r
 | 
			
		||||
                    else config_get(app.state.RELEVANCE_THRESHOLD)
 | 
			
		||||
                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
@ -543,7 +519,7 @@ def query_collection_handler(
 | 
			
		||||
                collection_names=form_data.collection_names,
 | 
			
		||||
                query=form_data.query,
 | 
			
		||||
                embedding_function=app.state.EMBEDDING_FUNCTION,
 | 
			
		||||
                k=form_data.k if form_data.k else config_get(app.state.TOP_K),
 | 
			
		||||
                k=form_data.k if form_data.k else app.state.config.TOP_K,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
@ -560,8 +536,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
 | 
			
		||||
        loader = YoutubeLoader.from_youtube_url(
 | 
			
		||||
            form_data.url,
 | 
			
		||||
            add_video_info=True,
 | 
			
		||||
            language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
 | 
			
		||||
            translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION),
 | 
			
		||||
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
 | 
			
		||||
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
 | 
			
		||||
        )
 | 
			
		||||
        data = loader.load()
 | 
			
		||||
 | 
			
		||||
@ -589,7 +565,7 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
 | 
			
		||||
    try:
 | 
			
		||||
        loader = get_web_loader(
 | 
			
		||||
            form_data.url,
 | 
			
		||||
            verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION),
 | 
			
		||||
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
 | 
			
		||||
        )
 | 
			
		||||
        data = loader.load()
 | 
			
		||||
 | 
			
		||||
@ -645,8 +621,8 @@ def resolve_hostname(hostname):
 | 
			
		||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 | 
			
		||||
 | 
			
		||||
    text_splitter = RecursiveCharacterTextSplitter(
 | 
			
		||||
        chunk_size=config_get(app.state.CHUNK_SIZE),
 | 
			
		||||
        chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
 | 
			
		||||
        chunk_size=app.state.config.CHUNK_SIZE,
 | 
			
		||||
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
 | 
			
		||||
        add_start_index=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -663,8 +639,8 @@ def store_text_in_vector_db(
 | 
			
		||||
    text, metadata, collection_name, overwrite: bool = False
 | 
			
		||||
) -> bool:
 | 
			
		||||
    text_splitter = RecursiveCharacterTextSplitter(
 | 
			
		||||
        chunk_size=config_get(app.state.CHUNK_SIZE),
 | 
			
		||||
        chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
 | 
			
		||||
        chunk_size=app.state.config.CHUNK_SIZE,
 | 
			
		||||
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
 | 
			
		||||
        add_start_index=True,
 | 
			
		||||
    )
 | 
			
		||||
    docs = text_splitter.create_documents([text], metadatas=[metadata])
 | 
			
		||||
@ -687,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
 | 
			
		||||
        collection = CHROMA_CLIENT.create_collection(name=collection_name)
 | 
			
		||||
 | 
			
		||||
        embedding_func = get_embedding_function(
 | 
			
		||||
            config_get(app.state.RAG_EMBEDDING_ENGINE),
 | 
			
		||||
            config_get(app.state.RAG_EMBEDDING_MODEL),
 | 
			
		||||
            app.state.config.RAG_EMBEDDING_ENGINE,
 | 
			
		||||
            app.state.config.RAG_EMBEDDING_MODEL,
 | 
			
		||||
            app.state.sentence_transformer_ef,
 | 
			
		||||
            config_get(app.state.OPENAI_API_KEY),
 | 
			
		||||
            config_get(app.state.OPENAI_API_BASE_URL),
 | 
			
		||||
            app.state.config.OPENAI_API_KEY,
 | 
			
		||||
            app.state.config.OPENAI_API_BASE_URL,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
 | 
			
		||||
@ -766,7 +742,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
 | 
			
		||||
 | 
			
		||||
    if file_ext == "pdf":
 | 
			
		||||
        loader = PyPDFLoader(
 | 
			
		||||
            file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES)
 | 
			
		||||
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
 | 
			
		||||
        )
 | 
			
		||||
    elif file_ext == "csv":
 | 
			
		||||
        loader = CSVLoader(file_path)
 | 
			
		||||
 | 
			
		||||
@ -22,21 +22,23 @@ from config import (
 | 
			
		||||
    WEBHOOK_URL,
 | 
			
		||||
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
 | 
			
		||||
    JWT_EXPIRES_IN,
 | 
			
		||||
    config_get,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
origins = ["*"]
 | 
			
		||||
 | 
			
		||||
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
 | 
			
		||||
app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
 | 
			
		||||
app.state.DEFAULT_MODELS = DEFAULT_MODELS
 | 
			
		||||
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 | 
			
		||||
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
 | 
			
		||||
app.state.USER_PERMISSIONS = USER_PERMISSIONS
 | 
			
		||||
app.state.WEBHOOK_URL = WEBHOOK_URL
 | 
			
		||||
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
 | 
			
		||||
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
 | 
			
		||||
 | 
			
		||||
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
 | 
			
		||||
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 | 
			
		||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
 | 
			
		||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
 | 
			
		||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
 | 
			
		||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 | 
			
		||||
 | 
			
		||||
app.add_middleware(
 | 
			
		||||
@ -63,6 +65,6 @@ async def get_status():
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "auth": WEBUI_AUTH,
 | 
			
		||||
        "default_models": config_get(app.state.DEFAULT_MODELS),
 | 
			
		||||
        "default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS),
 | 
			
		||||
        "default_models": app.state.config.DEFAULT_MODELS,
 | 
			
		||||
        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -33,7 +33,7 @@ from utils.utils import (
 | 
			
		||||
from utils.misc import parse_duration, validate_email_format
 | 
			
		||||
from utils.webhook import post_webhook
 | 
			
		||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 | 
			
		||||
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set
 | 
			
		||||
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
 | 
			
		||||
    if user:
 | 
			
		||||
        token = create_token(
 | 
			
		||||
            data={"id": user.id},
 | 
			
		||||
            expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)),
 | 
			
		||||
            expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
 | 
			
		||||
 | 
			
		||||
@router.post("/signup", response_model=SigninResponse)
 | 
			
		||||
async def signup(request: Request, form_data: SignupForm):
 | 
			
		||||
    if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH:
 | 
			
		||||
    if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
 | 
			
		||||
        )
 | 
			
		||||
@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
 | 
			
		||||
        role = (
 | 
			
		||||
            "admin"
 | 
			
		||||
            if Users.get_num_users() == 0
 | 
			
		||||
            else config_get(request.app.state.DEFAULT_USER_ROLE)
 | 
			
		||||
            else request.app.state.config.DEFAULT_USER_ROLE
 | 
			
		||||
        )
 | 
			
		||||
        hashed = get_password_hash(form_data.password)
 | 
			
		||||
        user = Auths.insert_new_auth(
 | 
			
		||||
@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
 | 
			
		||||
        if user:
 | 
			
		||||
            token = create_token(
 | 
			
		||||
                data={"id": user.id},
 | 
			
		||||
                expires_delta=parse_duration(
 | 
			
		||||
                    config_get(request.app.state.JWT_EXPIRES_IN)
 | 
			
		||||
                ),
 | 
			
		||||
                expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
 | 
			
		||||
            )
 | 
			
		||||
            # response.set_cookie(key='token', value=token, httponly=True)
 | 
			
		||||
 | 
			
		||||
            if config_get(request.app.state.WEBHOOK_URL):
 | 
			
		||||
            if request.app.state.config.WEBHOOK_URL:
 | 
			
		||||
                post_webhook(
 | 
			
		||||
                    config_get(request.app.state.WEBHOOK_URL),
 | 
			
		||||
                    request.app.state.config.WEBHOOK_URL,
 | 
			
		||||
                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
 | 
			
		||||
                    {
 | 
			
		||||
                        "action": "signup",
 | 
			
		||||
@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
 | 
			
		||||
 | 
			
		||||
@router.get("/signup/enabled", response_model=bool)
 | 
			
		||||
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return config_get(request.app.state.ENABLE_SIGNUP)
 | 
			
		||||
    return request.app.state.config.ENABLE_SIGNUP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.get("/signup/enabled/toggle", response_model=bool)
 | 
			
		||||
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(
 | 
			
		||||
        request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP)
 | 
			
		||||
    )
 | 
			
		||||
    return config_get(request.app.state.ENABLE_SIGNUP)
 | 
			
		||||
    request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
 | 
			
		||||
    return request.app.state.config.ENABLE_SIGNUP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
############################
 | 
			
		||||
@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
 | 
			
		||||
@router.get("/signup/user/role")
 | 
			
		||||
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return config_get(request.app.state.DEFAULT_USER_ROLE)
 | 
			
		||||
    return request.app.state.config.DEFAULT_USER_ROLE
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UpdateRoleForm(BaseModel):
 | 
			
		||||
@ -308,8 +304,8 @@ async def update_default_user_role(
 | 
			
		||||
    request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    if form_data.role in ["pending", "user", "admin"]:
 | 
			
		||||
        config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role)
 | 
			
		||||
    return config_get(request.app.state.DEFAULT_USER_ROLE)
 | 
			
		||||
        request.app.state.config.DEFAULT_USER_ROLE = form_data.role
 | 
			
		||||
    return request.app.state.config.DEFAULT_USER_ROLE
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
############################
 | 
			
		||||
@ -319,7 +315,7 @@ async def update_default_user_role(
 | 
			
		||||
 | 
			
		||||
@router.get("/token/expires")
 | 
			
		||||
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return config_get(request.app.state.JWT_EXPIRES_IN)
 | 
			
		||||
    return request.app.state.config.JWT_EXPIRES_IN
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UpdateJWTExpiresDurationForm(BaseModel):
 | 
			
		||||
@ -336,10 +332,10 @@ async def update_token_expires_duration(
 | 
			
		||||
 | 
			
		||||
    # Check if the input string matches the pattern
 | 
			
		||||
    if re.match(pattern, form_data.duration):
 | 
			
		||||
        config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration)
 | 
			
		||||
        return config_get(request.app.state.JWT_EXPIRES_IN)
 | 
			
		||||
        request.app.state.config.JWT_EXPIRES_IN = form_data.duration
 | 
			
		||||
        return request.app.state.config.JWT_EXPIRES_IN
 | 
			
		||||
    else:
 | 
			
		||||
        return config_get(request.app.state.JWT_EXPIRES_IN)
 | 
			
		||||
        return request.app.state.config.JWT_EXPIRES_IN
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
############################
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,6 @@ import time
 | 
			
		||||
import uuid
 | 
			
		||||
 | 
			
		||||
from apps.web.models.users import Users
 | 
			
		||||
from config import config_set, config_get
 | 
			
		||||
 | 
			
		||||
from utils.utils import (
 | 
			
		||||
    get_password_hash,
 | 
			
		||||
@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
 | 
			
		||||
async def set_global_default_models(
 | 
			
		||||
    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    config_set(request.app.state.DEFAULT_MODELS, form_data.models)
 | 
			
		||||
    return config_get(request.app.state.DEFAULT_MODELS)
 | 
			
		||||
    request.app.state.config.DEFAULT_MODELS = form_data.models
 | 
			
		||||
    return request.app.state.config.DEFAULT_MODELS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
 | 
			
		||||
@ -56,5 +55,5 @@ async def set_global_default_suggestions(
 | 
			
		||||
    user=Depends(get_admin_user),
 | 
			
		||||
):
 | 
			
		||||
    data = form_data.model_dump()
 | 
			
		||||
    config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"])
 | 
			
		||||
    return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS)
 | 
			
		||||
    request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
 | 
			
		||||
    return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
 | 
			
		||||
from utils.utils import get_current_user, get_password_hash, get_admin_user
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
 | 
			
		||||
from config import SRC_LOG_LEVELS, config_set, config_get
 | 
			
		||||
from config import SRC_LOG_LEVELS
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
 | 
			
		||||
@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 | 
			
		||||
 | 
			
		||||
@router.get("/permissions/user")
 | 
			
		||||
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return config_get(request.app.state.USER_PERMISSIONS)
 | 
			
		||||
    return request.app.state.config.USER_PERMISSIONS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/permissions/user")
 | 
			
		||||
async def update_user_permissions(
 | 
			
		||||
    request: Request, form_data: dict, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    config_set(request.app.state.USER_PERMISSIONS, form_data)
 | 
			
		||||
    return config_get(request.app.state.USER_PERMISSIONS)
 | 
			
		||||
    request.app.state.config.USER_PERMISSIONS = form_data
 | 
			
		||||
    return request.app.state.config.USER_PERMISSIONS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
############################
 | 
			
		||||
 | 
			
		||||
@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]):
 | 
			
		||||
        self.config_value = self.value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True):
 | 
			
		||||
    if isinstance(config, WrappedConfig):
 | 
			
		||||
        config.value = value
 | 
			
		||||
        if save_config:
 | 
			
		||||
            config.save()
 | 
			
		||||
    else:
 | 
			
		||||
        config = value
 | 
			
		||||
class AppConfig:
 | 
			
		||||
    _state: dict[str, WrappedConfig]
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__setattr__("_state", {})
 | 
			
		||||
 | 
			
		||||
def config_get(config: Union[WrappedConfig[T], T]) -> T:
 | 
			
		||||
    if isinstance(config, WrappedConfig):
 | 
			
		||||
        return config.value
 | 
			
		||||
    return config
 | 
			
		||||
    def __setattr__(self, key, value):
 | 
			
		||||
        if isinstance(value, WrappedConfig):
 | 
			
		||||
            self._state[key] = value
 | 
			
		||||
        else:
 | 
			
		||||
            self._state[key].value = value
 | 
			
		||||
            self._state[key].save()
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, key):
 | 
			
		||||
        return self._state[key].value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
####################################
 | 
			
		||||
 | 
			
		||||
@ -58,8 +58,7 @@ from config import (
 | 
			
		||||
    SRC_LOG_LEVELS,
 | 
			
		||||
    WEBHOOK_URL,
 | 
			
		||||
    ENABLE_ADMIN_EXPORT,
 | 
			
		||||
    config_get,
 | 
			
		||||
    config_set,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
 | 
			
		||||
@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui
 | 
			
		||||
 | 
			
		||||
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 | 
			
		||||
 | 
			
		||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 | 
			
		||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 | 
			
		||||
app.state.config = AppConfig()
 | 
			
		||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 | 
			
		||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
app.state.WEBHOOK_URL = WEBHOOK_URL
 | 
			
		||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
 | 
			
		||||
 | 
			
		||||
origins = ["*"]
 | 
			
		||||
 | 
			
		||||
@ -245,11 +245,9 @@ async def get_app_config():
 | 
			
		||||
        "version": VERSION,
 | 
			
		||||
        "auth": WEBUI_AUTH,
 | 
			
		||||
        "default_locale": default_locale,
 | 
			
		||||
        "images": config_get(images_app.state.ENABLED),
 | 
			
		||||
        "default_models": config_get(webui_app.state.DEFAULT_MODELS),
 | 
			
		||||
        "default_prompt_suggestions": config_get(
 | 
			
		||||
            webui_app.state.DEFAULT_PROMPT_SUGGESTIONS
 | 
			
		||||
        ),
 | 
			
		||||
        "images": images_app.state.config.ENABLED,
 | 
			
		||||
        "default_models": webui_app.state.config.DEFAULT_MODELS,
 | 
			
		||||
        "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
 | 
			
		||||
        "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
 | 
			
		||||
        "admin_export_enabled": ENABLE_ADMIN_EXPORT,
 | 
			
		||||
    }
 | 
			
		||||
@ -258,8 +256,8 @@ async def get_app_config():
 | 
			
		||||
@app.get("/api/config/model/filter")
 | 
			
		||||
async def get_model_filter_config(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "enabled": config_get(app.state.ENABLE_MODEL_FILTER),
 | 
			
		||||
        "models": config_get(app.state.MODEL_FILTER_LIST),
 | 
			
		||||
        "enabled": app.state.config.ENABLE_MODEL_FILTER,
 | 
			
		||||
        "models": app.state.config.MODEL_FILTER_LIST,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel):
 | 
			
		||||
async def update_model_filter_config(
 | 
			
		||||
    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled)
 | 
			
		||||
    config_set(app.state.MODEL_FILTER_LIST, form_data.models)
 | 
			
		||||
    app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
 | 
			
		||||
    app.state.config.MODEL_FILTER_LIST, form_data.models
 | 
			
		||||
 | 
			
		||||
    ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
 | 
			
		||||
    ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
 | 
			
		||||
    ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
 | 
			
		||||
    ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
    openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
 | 
			
		||||
    openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
 | 
			
		||||
    openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
 | 
			
		||||
    openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
    litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER)
 | 
			
		||||
    litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST)
 | 
			
		||||
    litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
 | 
			
		||||
    litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "enabled": config_get(app.state.ENABLE_MODEL_FILTER),
 | 
			
		||||
        "models": config_get(app.state.MODEL_FILTER_LIST),
 | 
			
		||||
        "enabled": app.state.config.ENABLE_MODEL_FILTER,
 | 
			
		||||
        "models": app.state.config.MODEL_FILTER_LIST,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/api/webhook")
 | 
			
		||||
async def get_webhook_url(user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "url": config_get(app.state.WEBHOOK_URL),
 | 
			
		||||
        "url": app.state.config.WEBHOOK_URL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -303,12 +301,12 @@ class UrlForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@app.post("/api/webhook")
 | 
			
		||||
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
 | 
			
		||||
    config_set(app.state.WEBHOOK_URL, form_data.url)
 | 
			
		||||
    app.state.config.WEBHOOK_URL = form_data.url
 | 
			
		||||
 | 
			
		||||
    webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL)
 | 
			
		||||
    webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "url": config_get(app.state.WEBHOOK_URL),
 | 
			
		||||
        "url": app.state.config.WEBHOOK_URL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user