From cdbabdfa5a26d78b2e4728012ca4e01e4df1b5b2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 17 May 2024 10:30:22 -0700 Subject: [PATCH] refac --- backend/apps/openai/main.py | 31 +++++++-- backend/config.py | 8 +++ src/lib/apis/openai/index.ts | 67 +++++++++++++++++++ src/lib/components/chat/MessageInput.svelte | 4 +- .../components/chat/Messages/CodeBlock.svelte | 5 +- .../chat/Settings/Connections.svelte | 31 +++++---- 6 files changed, 125 insertions(+), 21 deletions(-) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 65ed25f1c..fa4237a67 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -21,6 +21,7 @@ from utils.utils import ( ) from config import ( SRC_LOG_LEVELS, + ENABLE_OPENAI_API, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR, @@ -51,6 +52,8 @@ app.state.config = AppConfig() app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -68,6 +71,21 @@ async def check_url(request: Request, call_next): return response +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + + +class OpenAIConfigForm(BaseModel): + enable_openai_api: Optional[bool] = None + + +@app.post("/config/update") +async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): + app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api + return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + + class UrlsUpdateForm(BaseModel): urls: List[str] @@ -165,10 +183,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): async def fetch_url(url, key): try: - headers = {"Authorization": f"Bearer {key}"} - async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers) as response: - return await response.json() + if key != "": + headers = {"Authorization": f"Bearer {key}"} + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + return await response.json() + else: + return None except Exception as e: # Handle connection error here log.error(f"Connection error: {e}") @@ -200,7 +221,7 @@ async def get_all_models(): if ( len(app.state.config.OPENAI_API_KEYS) == 1 and app.state.config.OPENAI_API_KEYS[0] == "" - ): + ) or not app.state.config.ENABLE_OPENAI_API: models = {"data": []} else: tasks = [ diff --git a/backend/config.py b/backend/config.py index 112edba90..1a62e98bf 100644 --- a/backend/config.py +++ b/backend/config.py @@ -417,6 +417,14 @@ OLLAMA_BASE_URLS = PersistentConfig( # OPENAI_API #################################### + +ENABLE_OPENAI_API = PersistentConfig( + "ENABLE_OPENAI_API", + "openai.enable", + os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true", +) + + OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 41b6f9b6d..02281eff0 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,6 +1,73 @@ import { OPENAI_API_BASE_URL } from '$lib/constants'; import { promptTemplate } from '$lib/utils'; +export const getOpenAIConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/config`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_openai_api: enable_openai_api + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOpenAIUrls = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 106510204..a8a091c55 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -584,7 +584,7 @@ }} />
{ submitPrompt(prompt, user); }} @@ -754,7 +754,7 @@