mirror of
https://github.com/open-webui/open-webui
synced 2025-06-10 00:17:52 +00:00
openai embeddings function modified
This commit is contained in:
parent
8f6c3f46d6
commit
3ddebefca2
@ -411,6 +411,7 @@ from open_webui.utils.chat import (
|
||||
chat_completed as chat_completed_handler,
|
||||
chat_action as chat_action_handler,
|
||||
)
|
||||
from open_webui.utils.embeddings import generate_embeddings
|
||||
from open_webui.utils.middleware import process_chat_payload, process_chat_response
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
@ -1363,11 +1364,6 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified
|
||||
return {"task_ids": task_ids}
|
||||
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
async def api_embeddings(request: Request, user=Depends(get_verified_user)):
|
||||
return await openai.generate_embeddings(request=request, user=user)
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Config Endpoints
|
||||
@ -1544,6 +1540,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
|
||||
async def get_app_changelog():
|
||||
return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
|
||||
|
||||
##################################
|
||||
# Embeddings
|
||||
##################################
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
async def embeddings_endpoint(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
OpenAI-compatible embeddings endpoint.
|
||||
|
||||
This handler:
|
||||
- Performs user/model checks and dispatches to the correct backend.
|
||||
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
|
||||
|
||||
Args:
|
||||
request (Request): Request context.
|
||||
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
|
||||
user (UserModel): Authenticated user.
|
||||
|
||||
Returns:
|
||||
dict: OpenAI-compatible embeddings response.
|
||||
"""
|
||||
# Make sure models are loaded in app state
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
# Use generic dispatcher in utils.embeddings
|
||||
return await generate_embeddings(request, form_data, user)
|
||||
|
||||
|
||||
############################
|
||||
# OAuth Login & Callback
|
||||
|
@ -886,26 +886,36 @@ async def generate_chat_completion(
|
||||
r.close()
|
||||
await session.close()
|
||||
|
||||
@router.post("/embeddings")
|
||||
async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
|
||||
async def embeddings(request: Request, form_data: dict, user):
|
||||
"""
|
||||
Call embeddings endpoint
|
||||
Calls the embeddings endpoint for OpenAI-compatible providers.
|
||||
|
||||
Args:
|
||||
request (Request): The FastAPI request context.
|
||||
form_data (dict): OpenAI-compatible embeddings payload.
|
||||
user (UserModel): The authenticated user.
|
||||
|
||||
Returns:
|
||||
dict: OpenAI-compatible embeddings response.
|
||||
"""
|
||||
|
||||
body = await request.body()
|
||||
|
||||
idx = 0
|
||||
# Prepare payload/body
|
||||
body = json.dumps(form_data)
|
||||
# Find correct backend url/key based on model
|
||||
await get_all_models(request, user=user)
|
||||
model_id = form_data.get("model")
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if model_id in models:
|
||||
idx = models[model_id]["urlIdx"]
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
method="POST",
|
||||
url=f"{url}/embeddings",
|
||||
data=body,
|
||||
headers={
|
||||
@ -918,14 +928,11 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
# Check if response is SSE
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
@ -939,10 +946,8 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
|
||||
else:
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user