diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index b44f68017..0ac1f0401 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -195,7 +195,10 @@ async def post_streaming_url( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -210,13 +213,13 @@ async def post_streaming_url( r.raise_for_status() if stream: - headers = dict(r.headers) + response_headers = dict(r.headers) if content_type: - headers["Content-Type"] = content_type + response_headers["Content-Type"] = content_type return StreamingResponse( r.content, status_code=r.status, - headers=headers, + headers=response_headers, background=BackgroundTask( cleanup_response, response=r, session=session ), @@ -324,7 +327,10 @@ async def get_ollama_tags( else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {} @@ -525,7 +531,10 @@ async def copy_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -584,7 +593,10 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -730,7 +745,10 @@ async def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -974,7 +995,10 @@ async def generate_chat_completion( log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "")