mirror of
https://github.com/open-webui/open-webui
synced 2025-06-14 18:33:15 +00:00
refac: embeddings endpoint
This commit is contained in:
parent
b02a3da4da
commit
ab36b8aeae
@ -1208,6 +1208,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
|||||||
return {"data": models}
|
return {"data": models}
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
# Embeddings
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/embeddings")
|
||||||
|
async def embeddings(
|
||||||
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
OpenAI-compatible embeddings endpoint.
|
||||||
|
|
||||||
|
This handler:
|
||||||
|
- Performs user/model checks and dispatches to the correct backend.
|
||||||
|
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): Request context.
|
||||||
|
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
|
||||||
|
user (UserModel): Authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: OpenAI-compatible embeddings response.
|
||||||
|
"""
|
||||||
|
# Make sure models are loaded in app state
|
||||||
|
if not request.app.state.MODELS:
|
||||||
|
await get_all_models(request, user=user)
|
||||||
|
# Use generic dispatcher in utils.embeddings
|
||||||
|
return await generate_embeddings(request, form_data, user)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completions")
|
@app.post("/api/chat/completions")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -1550,37 +1581,6 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
|
|||||||
async def get_app_changelog():
|
async def get_app_changelog():
|
||||||
return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
|
return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
|
||||||
|
|
||||||
##################################
|
|
||||||
# Embeddings
|
|
||||||
##################################
|
|
||||||
|
|
||||||
@app.post("/api/embeddings")
|
|
||||||
async def embeddings_endpoint(
|
|
||||||
request: Request,
|
|
||||||
form_data: dict,
|
|
||||||
user=Depends(get_verified_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
OpenAI-compatible embeddings endpoint.
|
|
||||||
|
|
||||||
This handler:
|
|
||||||
- Performs user/model checks and dispatches to the correct backend.
|
|
||||||
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (Request): Request context.
|
|
||||||
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
|
|
||||||
user (UserModel): Authenticated user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: OpenAI-compatible embeddings response.
|
|
||||||
"""
|
|
||||||
# Make sure models are loaded in app state
|
|
||||||
if not request.app.state.MODELS:
|
|
||||||
await get_all_models(request, user=user)
|
|
||||||
# Use generic dispatcher in utils.embeddings
|
|
||||||
return await generate_embeddings(request, form_data, user)
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# OAuth Login & Callback
|
# OAuth Login & Callback
|
||||||
|
@ -9,9 +9,10 @@ from open_webui.utils.models import check_model_access
|
|||||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||||
|
|
||||||
from open_webui.routers.openai import embeddings as openai_embeddings
|
from open_webui.routers.openai import embeddings as openai_embeddings
|
||||||
from open_webui.routers.ollama import embeddings as ollama_embeddings
|
from open_webui.routers.ollama import (
|
||||||
from open_webui.routers.ollama import GenerateEmbeddingsForm
|
embeddings as ollama_embeddings,
|
||||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
GenerateEmbeddingsForm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
|
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
|
||||||
@ -29,7 +30,7 @@ async def generate_embeddings(
|
|||||||
bypass_filter: bool = False,
|
bypass_filter: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
|
Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (Request): The FastAPI request context.
|
request (Request): The FastAPI request context.
|
||||||
@ -71,50 +72,12 @@ async def generate_embeddings(
|
|||||||
if not bypass_filter and user.role == "user":
|
if not bypass_filter and user.role == "user":
|
||||||
check_model_access(user, model)
|
check_model_access(user, model)
|
||||||
|
|
||||||
# Arena "meta-model": select a submodel at random
|
|
||||||
if model.get("owned_by") == "arena":
|
|
||||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
|
||||||
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
|
||||||
if model_ids and filter_mode == "exclude":
|
|
||||||
model_ids = [
|
|
||||||
m["id"]
|
|
||||||
for m in list(models.values())
|
|
||||||
if m.get("owned_by") != "arena" and m["id"] not in model_ids
|
|
||||||
]
|
|
||||||
if isinstance(model_ids, list) and model_ids:
|
|
||||||
selected_model_id = random.choice(model_ids)
|
|
||||||
else:
|
|
||||||
model_ids = [
|
|
||||||
m["id"]
|
|
||||||
for m in list(models.values())
|
|
||||||
if m.get("owned_by") != "arena"
|
|
||||||
]
|
|
||||||
selected_model_id = random.choice(model_ids)
|
|
||||||
inner_form = dict(form_data)
|
|
||||||
inner_form["model"] = selected_model_id
|
|
||||||
response = await generate_embeddings(
|
|
||||||
request, inner_form, user, bypass_filter=True
|
|
||||||
)
|
|
||||||
# Tag which concreted model was chosen
|
|
||||||
if isinstance(response, dict):
|
|
||||||
response = {
|
|
||||||
**response,
|
|
||||||
"selected_model_id": selected_model_id,
|
|
||||||
}
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Pipeline/Function models
|
|
||||||
if model.get("pipe"):
|
|
||||||
# The pipeline handler should provide OpenAI-compatible schema
|
|
||||||
return await process_pipeline_inlet_filter(request, form_data, user, models)
|
|
||||||
|
|
||||||
# Ollama backend
|
# Ollama backend
|
||||||
if model.get("owned_by") == "ollama":
|
if model.get("owned_by") == "ollama":
|
||||||
ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
|
ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
|
||||||
form_obj = GenerateEmbeddingsForm(**ollama_payload)
|
|
||||||
response = await ollama_embeddings(
|
response = await ollama_embeddings(
|
||||||
request=request,
|
request=request,
|
||||||
form_data=form_obj,
|
form_data=GenerateEmbeddingsForm(**ollama_payload),
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
return convert_embedding_response_ollama_to_openai(response)
|
return convert_embedding_response_ollama_to_openai(response)
|
||||||
|
Loading…
Reference in New Issue
Block a user