From 8dc73e87440c92fd8ab41052a3b3cbfd46593911 Mon Sep 17 00:00:00 2001 From: bnodnarb <97063458+bnodnarb@users.noreply.github.com> Date: Sun, 24 Nov 2024 20:29:54 -1000 Subject: [PATCH] Fix: Add authorization header with bearer token for remote Ollama server endpoints --- backend/open_webui/apps/ollama/main.py | 46 ++++++++++++++++++++------ 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index b44f68017..0ac1f0401 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -195,7 +195,10 @@ async def post_streaming_url( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -210,13 +213,13 @@ async def post_streaming_url( r.raise_for_status() if stream: - headers = dict(r.headers) + response_headers = dict(r.headers) if content_type: - headers["Content-Type"] = content_type + response_headers["Content-Type"] = content_type return StreamingResponse( r.content, status_code=r.status, - headers=headers, + headers=response_headers, background=BackgroundTask( cleanup_response, response=r, session=session ), @@ -324,7 +327,10 @@ async def get_ollama_tags( else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {} @@ -525,7 +531,10 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -584,7 +593,10 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -730,7 +745,10 @@ async def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -974,7 +995,10 @@ async def generate_chat_completion( log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "")