diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 6db426439..6a355038b 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file: litellm_config = yaml.safe_load(file) +app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value + + app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config @@ -151,10 +155,6 @@ async def shutdown_litellm_background(): background_process = None -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST - - @app.get("/") async def get_status(): return {"status": True} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index cb80eeed2..df268067f 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -64,8 +64,8 @@ app.add_middleware( app.state.config = AppConfig() -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -124,8 +124,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) async def fetch_url(url): + timeout = aiohttp.ClientTimeout(total=5) try: - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=timeout) as session: async with session.get(url) as response: return await response.json() except Exception as e: @@ -177,11 +178,12 @@ async def get_ollama_tags( if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in app.state.config.MODEL_FILTER_LIST, models["models"], ) ) @@ -1045,11 +1047,12 @@ async def get_openai_models( if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + lambda model: model["name"] + in app.state.config.MODEL_FILTER_LIST, models["models"], ) ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 65ed25f1c..85ee531f1 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -21,6 +21,7 @@ from utils.utils import ( ) from config import ( SRC_LOG_LEVELS, + ENABLE_OPENAI_API, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR, @@ -46,11 +47,14 @@ app.add_middleware( allow_headers=["*"], ) + app.state.config = AppConfig() -app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER -app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER +app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -68,6 +72,21 @@ async def check_url(request: Request, call_next): return response +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + + +class OpenAIConfigForm(BaseModel): + enable_openai_api: Optional[bool] = None + + +@app.post("/config/update") +async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): + app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api + return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + + class UrlsUpdateForm(BaseModel): urls: List[str] @@ -164,11 +183,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): async def fetch_url(url, key): + timeout = aiohttp.ClientTimeout(total=5) try: - headers = {"Authorization": f"Bearer {key}"} - async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers) as response: - return await response.json() + if key != "": + headers = {"Authorization": f"Bearer {key}"} + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as response: + return await response.json() + else: + return None except Exception as e: # Handle connection error here log.error(f"Connection error: {e}") @@ -200,7 +223,7 @@ async def get_all_models(): if ( len(app.state.config.OPENAI_API_KEYS) == 1 and app.state.config.OPENAI_API_KEYS[0] == "" - ): + ) or not app.state.config.ENABLE_OPENAI_API: models = {"data": []} else: tasks = [ @@ -237,11 +260,11 @@ async def get_all_models(): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: models = await get_all_models() - if app.state.ENABLE_MODEL_FILTER: + if app.state.config.ENABLE_MODEL_FILTER: if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, models["data"], ) ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index d2c3964ae..ba25f34f6 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -433,12 +433,12 @@ async def update_query_settings( form_data: QuerySettingsForm, user=Depends(get_admin_user) ): app.state.config.RAG_TEMPLATE = ( - form_data.template if form_data.template else RAG_TEMPLATE, + form_data.template if form_data.template else RAG_TEMPLATE ) app.state.config.TOP_K = form_data.k if form_data.k else 4 app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( - form_data.hybrid if form_data.hybrid else False, + form_data.hybrid if form_data.hybrid else False ) return { "status": True, diff --git a/backend/config.py b/backend/config.py index 112edba90..1a62e98bf 100644 --- a/backend/config.py +++ b/backend/config.py @@ -417,6 +417,14 @@ OLLAMA_BASE_URLS = PersistentConfig( # OPENAI_API #################################### + +ENABLE_OPENAI_API = PersistentConfig( + "ENABLE_OPENAI_API", + "openai.enable", + os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true", +) + + OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") diff --git a/backend/main.py b/backend/main.py index 3d1ed6c2d..209199591 100644 --- a/backend/main.py +++ b/backend/main.py @@ -118,15 +118,15 @@ origins = ["*"] # Custom middleware to add security headers -class SecurityHeadersMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - response: Response = await call_next(request) - response.headers["Cross-Origin-Opener-Policy"] = "same-origin" - response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" - return response +# class SecurityHeadersMiddleware(BaseHTTPMiddleware): +# async def dispatch(self, request: Request, call_next): +# response: Response = await call_next(request) +# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" +# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" +# return response -app.add_middleware(SecurityHeadersMiddleware) +# app.add_middleware(SecurityHeadersMiddleware) class RAGMiddleware(BaseHTTPMiddleware): @@ -289,14 +289,14 @@ class ModelFilterConfigForm(BaseModel): async def update_model_filter_config( form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): - app.state.config.ENABLE_MODEL_FILTER, form_data.enabled - app.state.config.MODEL_FILTER_LIST, form_data.models + app.state.config.ENABLE_MODEL_FILTER = form_data.enabled + app.state.config.MODEL_FILTER_LIST = form_data.models - ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST + ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST - openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER - openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST + openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER + openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 41b6f9b6d..02281eff0 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,6 +1,73 @@ import { OPENAI_API_BASE_URL } from '$lib/constants'; import { promptTemplate } from '$lib/utils'; +export const getOpenAIConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/config`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_openai_api: enable_openai_api + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOpenAIUrls = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 7f4bc36a9..f8050030a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -585,7 +585,7 @@ />
{ submitPrompt(prompt, user); }} @@ -755,7 +755,7 @@