diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index a629478e4..b1a75a9cb 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -411,6 +411,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
 
@@ -1363,11 +1364,6 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified
     return {"task_ids": task_ids}
 
 
-@app.post("/api/embeddings")
-async def api_embeddings(request: Request, user=Depends(get_verified_user)):
-    return await openai.generate_embeddings(request=request, user=user)
-
-
 ##################################
 #
 # Config Endpoints
@@ -1544,6 +1540,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
+##################################
+# Embeddings
+##################################
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+    - Performs user/model checks and dispatches to the correct backend.
+    - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
 
 ############################
 # OAuth Login & Callback
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 5343c3e7a..486246b64 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -886,26 +886,36 @@ async def generate_chat_completion(
     r.close()
     await session.close()
 
-@router.post("/embeddings")
-async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
+async def embeddings(request: Request, form_data: dict, user):
     """
-    Call embeddings endpoint
+    Calls the embeddings endpoint for OpenAI-compatible providers.
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): OpenAI-compatible embeddings payload.
+        user (UserModel): The authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
     """
-
-    body = await request.body()
-    idx = 0
+    # Prepare payload/body
+    body = json.dumps(form_data)
+    # Find correct backend url/key based on model
+    await get_all_models(request, user=user)
+    model_id = form_data.get("model")
+    models = request.app.state.OPENAI_MODELS
+    # Default to the first backend when the model is unknown (avoids NameError)
+    idx = models[model_id]["urlIdx"] if model_id in models else 0
 
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
-
     r = None
     session = None
     streaming = False
-
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
-            method=request.method,
+            method="POST",
             url=f"{url}/embeddings",
             data=body,
             headers={
@@ -918,14 +928,11 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
                         "X-OpenWebUI-User-Email": user.email,
                         "X-OpenWebUI-User-Role": user.role,
                     }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
                 ),
             },
         )
         r.raise_for_status()
-
-        # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
             streaming = True
             return StreamingResponse(
@@ -939,10 +946,8 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
         else:
             response_data = await r.json()
             return response_data
-
     except Exception as e:
         log.exception(e)
-
         detail = None
         if r is not None:
             try: