openai embeddings function modified

This commit is contained in:
henry 2025-06-04 16:13:53 +02:00
parent 8f6c3f46d6
commit 3ddebefca2
2 changed files with 52 additions and 20 deletions

View File

@ -411,6 +411,7 @@ from open_webui.utils.chat import (
chat_completed as chat_completed_handler, chat_completed as chat_completed_handler,
chat_action as chat_action_handler, chat_action as chat_action_handler,
) )
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
@ -1363,11 +1364,6 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified
return {"task_ids": task_ids} return {"task_ids": task_ids}
@app.post("/api/embeddings")
async def api_embeddings(request: Request, user=Depends(get_verified_user)):
return await openai.generate_embeddings(request=request, user=user)
################################## ##################################
# #
# Config Endpoints # Config Endpoints
@ -1544,6 +1540,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
async def get_app_changelog(): async def get_app_changelog():
return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
##################################
# Embeddings
##################################
@app.post("/api/embeddings")
async def embeddings_endpoint(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""
OpenAI-compatible embeddings endpoint.
This handler:
- Performs user/model checks and dispatches to the correct backend.
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
Args:
request (Request): Request context.
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
user (UserModel): Authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
"""
# Make sure models are loaded in app state
if not request.app.state.MODELS:
await get_all_models(request, user=user)
# Use generic dispatcher in utils.embeddings
return await generate_embeddings(request, form_data, user)
############################ ############################
# OAuth Login & Callback # OAuth Login & Callback

View File

@ -886,26 +886,36 @@ async def generate_chat_completion(
r.close() r.close()
await session.close() await session.close()
@router.post("/embeddings") async def embeddings(request: Request, form_data: dict, user):
async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
""" """
Call embeddings endpoint Calls the embeddings endpoint for OpenAI-compatible providers.
Args:
request (Request): The FastAPI request context.
form_data (dict): OpenAI-compatible embeddings payload.
user (UserModel): The authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
""" """
body = await request.body()
idx = 0 idx = 0
# Prepare payload/body
body = json.dumps(form_data)
# Find correct backend url/key based on model
await get_all_models(request, user=user)
model_id = form_data.get("model")
models = request.app.state.OPENAI_MODELS
if model_id in models:
idx = models[model_id]["urlIdx"]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx]
r = None r = None
session = None session = None
streaming = False streaming = False
try: try:
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(trust_env=True)
r = await session.request( r = await session.request(
method=request.method, method="POST",
url=f"{url}/embeddings", url=f"{url}/embeddings",
data=body, data=body,
headers={ headers={
@ -918,14 +928,11 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
} }
if ENABLE_FORWARD_USER_INFO_HEADERS if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
else {}
), ),
}, },
) )
r.raise_for_status() r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""): if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True streaming = True
return StreamingResponse( return StreamingResponse(
@ -939,10 +946,8 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
else: else:
response_data = await r.json() response_data = await r.json()
return response_data return response_data
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
detail = None detail = None
if r is not None: if r is not None:
try: try: