# NOTE(review): recovered from a whitespace-collapsed diff of
# backend/open_webui/apps/ollama/main.py. The same diff also changes the
# legacy handler registered under "/api/embeddings" (not fully visible here)
# so that it posts to f"{url}/api/embeddings" instead of f"{url}/api/embed".
# NOTE(review): that legacy handler is also named ``generate_embeddings``;
# both routes register, but the later ``def`` rebinds the module-level name.
@app.post("/api/embed")
@app.post("/api/embed/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward an embedding request to an Ollama backend's ``/api/embed``.

    Args:
        form_data: Request payload; serialized with ``exclude_none=True`` and
            posted to the backend as JSON.
        url_idx: Index into ``app.state.config.OLLAMA_BASE_URLS``. When
            omitted, a backend is chosen at random among the URLs registered
            for the requested model in ``app.state.MODELS``.
        user: Result of the ``get_verified_user`` dependency (unused here).

    Returns:
        The decoded JSON body of the backend response.

    Raises:
        HTTPException: 400 if ``url_idx`` is omitted and the model is not in
            ``app.state.MODELS``; otherwise the backend's own error status
            (or 500 when the backend could not be reached or returned an
            unreadable body) with a descriptive ``detail``.
    """
    if url_idx is None:
        model = form_data.model

        # Untagged model names are looked up under the ":latest" tag.
        if ":" not in model:
            model = f"{model}:latest"

        if model not in app.state.MODELS:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )
        url_idx = random.choice(app.state.MODELS[model]["urls"])

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Stays None if the request itself fails (e.g. connection refused), so the
    # handler below can tell "no response" apart from "error response".
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers={"Content-Type": "application/json"},
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        status_code = 500

        # Compare against None explicitly: a requests.Response is falsy for
        # every 4xx/5xx status, so a truthiness test would discard exactly
        # the responses whose status and body we want to relay.
        if r is not None:
            # Never raise an "error" carrying a success status (possible when
            # a 2xx response has an undecodable body).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if isinstance(res, dict) and "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=status_code,
            detail=error_detail,
        )