"""Embeddings dispatcher for Open WebUI.

Routes an OpenAI-compatible embeddings request to the appropriate backend:
an OpenAI-compatible endpoint (default), an Ollama backend (with payload /
response conversion), a pipeline/function model, or an "arena" meta-model
that delegates to a randomly selected concrete sub-model.
"""

import random
import logging
import sys

from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL

from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import embeddings as ollama_embeddings
from open_webui.routers.pipelines import process_pipeline_inlet_filter

from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_response_ollama_to_openai

# Module-level logger wired to the project's global log configuration.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type
    (OpenAI, Ollama, Arena, pipeline, etc).

    Args:
        request (Request): The FastAPI request context. ``request.state`` may
            carry ``metadata`` (merged into ``form_data``), and ``direct`` /
            ``model`` (restricting resolution to a single pre-selected model).
        form_data (dict): The input data sent to the endpoint. Must contain a
            ``"model"`` key identifying the target model.
        user (UserModel): The authenticated user; ``user.role`` gates access
            filtering.
        bypass_filter (bool): If True, disables access filtering (default False).

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.
        For arena models, a ``"selected_model_id"`` key is added identifying
        the delegated sub-model.

    Raises:
        Exception: If the requested model id is not found.
    """
    # Global env switch disables per-user access control everywhere.
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present; on key collisions
    # the request-level metadata wins (it is spread second).
    # NOTE(review): when form_data has no "metadata" key, request.state.metadata
    # is attached by reference, not copied — later mutation of form_data
    # ["metadata"] would alias request state; confirm callers never mutate it.
    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # If the "direct" flag is present, resolve against only that single model
    # instead of the app-wide model registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering: only enforced for non-direct requests, and only for
    # plain "user"-role accounts (admins and bypassed calls skip the check).
    if not getattr(request.state, "direct", False):
        if not bypass_filter and user.role == "user":
            check_model_access(user, model)

    # Arena "meta-model": select a concrete sub-model at random and recurse.
    if model.get("owned_by") == "arena":
        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
        if model_ids and filter_mode == "exclude":
            # "exclude" inverts the list: candidates are every non-arena model
            # NOT named in model_ids.
            model_ids = [
                m["id"]
                for m in list(models.values())
                if m.get("owned_by") != "arena" and m["id"] not in model_ids
            ]
        if isinstance(model_ids, list) and model_ids:
            selected_model_id = random.choice(model_ids)
        else:
            # No usable candidate list configured (or exclusion emptied it):
            # fall back to choosing among all non-arena models.
            model_ids = [
                m["id"]
                for m in list(models.values())
                if m.get("owned_by") != "arena"
            ]
            selected_model_id = random.choice(model_ids)
        # Recurse on a copy of the payload with bypass_filter=True: access was
        # already checked on the arena model, not the delegated sub-model.
        inner_form = dict(form_data)
        inner_form["model"] = selected_model_id
        response = await generate_embeddings(
            request, inner_form, user, bypass_filter=True
        )
        # Tag which concrete model was chosen so the caller can see it.
        if isinstance(response, dict):
            response = {
                **response,
                "selected_model_id": selected_model_id,
            }
        return response

    # Pipeline/Function models
    if model.get("pipe"):
        # The pipeline handler should provide OpenAI-compatible schema
        return await process_pipeline_inlet_filter(request, form_data, user, models)

    # Ollama backend: translate the OpenAI-shaped payload to Ollama's embed
    # format, then convert the Ollama response back to OpenAI shape.
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=ollama_payload,
            user=user,
        )
        return convert_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )