From 99446c4b7615bc072fcbaf1ce0bcd5fd7c843132 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 11 Nov 2024 22:33:18 -0800 Subject: [PATCH] feat: ollama auth support --- backend/open_webui/apps/ollama/main.py | 73 ++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 67268a95b..9e4a8a451 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -209,10 +209,18 @@ async def post_streaming_url( session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = await session.post( url, data=payload, - headers={"Content-Type": "application/json"}, + headers=headers, ) r.raise_for_status() @@ -275,9 +283,10 @@ async def get_all_models(): else: api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) + key = api_config.get("key", None) if enable: - tasks.append(aiohttp_get(f"{url}/api/tags")) + tasks.append(aiohttp_get(f"{url}/api/tags", key)) else: tasks.append(None) @@ -341,9 +350,16 @@ async def get_ollama_tags( else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {} + if key: + headers["Authorization"] = f"Bearer {key}" + r = None try: - r = requests.request(method="GET", url=f"{url}/api/tags") + r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers) r.raise_for_status() return r.json() @@ -371,7 +387,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx is None: # returns lowest version tasks = [ - aiohttp_get(f"{url}/api/version") + aiohttp_get( + f"{url}/api/version", + app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + ) for url in app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) @@ -511,10 +530,18 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/copy", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -560,11 +587,18 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="DELETE", url=f"{url}/api/delete", - headers={"Content-Type": "application/json"}, data=form_data.model_dump_json(exclude_none=True).encode(), + headers=headers, ) try: r.raise_for_status() @@ -601,10 +635,17 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/show", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -686,10 +727,17 @@ def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/embeddings", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: @@ -743,10 +791,17 @@ def generate_ollama_batch_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + key = api_config.get("key", None) + + headers = {"Content-Type": "application/json"} + if key: + headers["Authorization"] = f"Bearer {key}" + r = requests.request( method="POST", url=f"{url}/api/embed", - headers={"Content-Type": "application/json"}, + headers=headers, data=form_data.model_dump_json(exclude_none=True).encode(), ) try: