From d077b3dcdb92e9ebb8cec5952e56af5f69707881 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek" <timothyjrbeck@gmail.com>
Date: Wed, 5 Jun 2024 09:08:52 -0700
Subject: [PATCH] fix: ollama version request when ollama api is disabled

---
 backend/apps/ollama/main.py | 89 +++++++++++++++++++------------------
 backend/constants.py        |  4 ++
 2 files changed, 50 insertions(+), 43 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 35564f0c1..82cd8d383 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -274,54 +274,57 @@ async def get_ollama_tags(
 @app.get("/api/version")
 @app.get("/api/version/{url_idx}")
 async def get_ollama_versions(url_idx: Optional[int] = None):
+    if app.state.config.ENABLE_OLLAMA_API:
+        if url_idx == None:
 
-    if url_idx == None:
+            # returns lowest version
+            tasks = [
+                fetch_url(f"{url}/api/version")
+                for url in app.state.config.OLLAMA_BASE_URLS
+            ]
+            responses = await asyncio.gather(*tasks)
+            responses = list(filter(lambda x: x is not None, responses))
 
-        # returns lowest version
-        tasks = [
-            fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
-        ]
-        responses = await asyncio.gather(*tasks)
-        responses = list(filter(lambda x: x is not None, responses))
+            if len(responses) > 0:
+                lowest_version = min(
+                    responses,
+                    key=lambda x: tuple(
+                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
+                    ),
+                )
 
-        if len(responses) > 0:
-            lowest_version = min(
-                responses,
-                key=lambda x: tuple(
-                    map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
-                ),
-            )
-
-            return {"version": lowest_version["version"]}
+                return {"version": lowest_version["version"]}
+            else:
+                raise HTTPException(
+                    status_code=500,
+                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
+                )
         else:
-            raise HTTPException(
-                status_code=500,
-                detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
-            )
+            url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+
+            r = None
+            try:
+                r = requests.request(method="GET", url=f"{url}/api/version")
+                r.raise_for_status()
+
+                return r.json()
+            except Exception as e:
+                log.exception(e)
+                error_detail = "Open WebUI: Server Connection Error"
+                if r is not None:
+                    try:
+                        res = r.json()
+                        if "error" in res:
+                            error_detail = f"Ollama: {res['error']}"
+                    except:
+                        error_detail = f"Ollama: {e}"
+
+                raise HTTPException(
+                    status_code=r.status_code if r else 500,
+                    detail=error_detail,
+                )
     else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-        r = None
-        try:
-            r = requests.request(method="GET", url=f"{url}/api/version")
-            r.raise_for_status()
-
-            return r.json()
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
-                except:
-                    error_detail = f"Ollama: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
+        return {"version": False}
 
 
 class ModelNameForm(BaseModel):
diff --git a/backend/constants.py b/backend/constants.py
index 72ca0c413..0740fa49d 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -84,3 +84,7 @@ class ERROR_MESSAGES(str, Enum):
     WEB_SEARCH_ERROR = (
         lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
     )
+
+    OLLAMA_API_DISABLED = (
+        "The Ollama API is disabled. Please enable it to use this feature."
+    )