diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index df268067f..fb8a35a17 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -43,6 +43,7 @@ from utils.utils import ( from config import ( SRC_LOG_LEVELS, OLLAMA_BASE_URLS, + ENABLE_OLLAMA_API, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, @@ -67,6 +68,8 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST + +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -96,6 +99,21 @@ async def get_status(): return {"status": True} +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + + +class OllamaConfigForm(BaseModel): + enable_ollama_api: Optional[bool] = None + + +@app.post("/config/update") +async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): + app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api + return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + + @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} @@ -156,14 +174,23 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS] - responses = await asyncio.gather(*tasks) - models = { - "models": merge_models_lists( - map(lambda response: response["models"] if response else None, responses) - ) - } + if app.state.config.ENABLE_OLLAMA_API: + tasks = [ + fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS + ] + responses = await asyncio.gather(*tasks) + + models = { + "models": merge_models_lists( + map( + lambda response: response["models"] if response else None, responses + ) + ) + } + + else: + models = {"models": []} app.state.MODELS = {model["model"]: model for model in models["models"]} diff --git a/backend/config.py b/backend/config.py index 9ba059008..0b18eab43 100644 --- a/backend/config.py +++ b/backend/config.py @@ -384,6 +384,13 @@ if not os.path.exists(LITELLM_CONFIG_PATH): # OLLAMA_BASE_URL #################################### + +ENABLE_OLLAMA_API = PersistentConfig( + "ENABLE_OLLAMA_API", + "ollama.enable", + os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true", +) + OLLAMA_API_BASE_URL = os.environ.get( "OLLAMA_API_BASE_URL", "http://localhost:11434/api" ) diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 7ecd65efe..b7f842177 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,6 +1,73 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; import { promptTemplate } from '$lib/utils'; +export const getOllamaConfig = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/config`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const 
updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + enable_ollama_api: enable_ollama_api + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOllamaUrls = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/chat/Settings/Connections.svelte index 01a47efbc..b9978e129 100644 --- a/src/lib/components/chat/Settings/Connections.svelte +++ b/src/lib/components/chat/Settings/Connections.svelte @@ -3,7 +3,13 @@ import { createEventDispatcher, onMount, getContext } from 'svelte'; const dispatch = createEventDispatcher(); - import { getOllamaUrls, getOllamaVersion, updateOllamaUrls } from '$lib/apis/ollama'; + import { + getOllamaConfig, + getOllamaUrls, + getOllamaVersion, + updateOllamaConfig, + updateOllamaUrls + } from '$lib/apis/ollama'; import { getOpenAIConfig, getOpenAIKeys, @@ -26,6 +32,7 @@ let OPENAI_API_BASE_URLS = ['']; let ENABLE_OPENAI_API = false; + let ENABLE_OLLAMA_API = false; const updateOpenAIHandler = async () => { OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); @@ -50,10 +57,13 @@ onMount(async () => { if ($user.role === 'admin') { - OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token); + const ollamaConfig = await getOllamaConfig(localStorage.token); + const openaiConfig = await getOpenAIConfig(localStorage.token); - const config = await getOpenAIConfig(localStorage.token); - ENABLE_OPENAI_API = config.ENABLE_OPENAI_API; + ENABLE_OPENAI_API = openaiConfig.ENABLE_OPENAI_API; + ENABLE_OLLAMA_API = ollamaConfig.ENABLE_OLLAMA_API; + + OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token); OPENAI_API_BASE_URLS = await getOpenAIUrls(localStorage.token); OPENAI_API_KEYS = await getOpenAIKeys(localStorage.token); @@ -161,95 +171,108 @@
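
For reference, a minimal sketch of how the new frontend helpers are meant to be
called from an admin-only component. The toggleOllamaApi wrapper below is
hypothetical and only for illustration; it assumes an admin token in
localStorage, as the Connections.svelte changes above do. Note that when
ENABLE_OLLAMA_API is set to false, get_all_models() short-circuits to
{"models": []}, so no Ollama models are listed anywhere in the UI.

import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';

// Hypothetical admin-side wrapper around the new config endpoints.
const toggleOllamaApi = async (): Promise<boolean> => {
	// GET /config (admin only) -> { ENABLE_OLLAMA_API: boolean }
	const config = await getOllamaConfig(localStorage.token);

	// POST /config/update with { enable_ollama_api: ... }; the backend
	// echoes the persisted state back as { ENABLE_OLLAMA_API: boolean }.
	const updated = await updateOllamaConfig(
		localStorage.token,
		!config.ENABLE_OLLAMA_API
	);
	return updated.ENABLE_OLLAMA_API;
};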