mirror of https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00

openai embeddings function modified

parent 8f6c3f46d6
commit 3ddebefca2
@@ -411,6 +411,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
 
@@ -1363,11 +1364,6 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified
     return {"task_ids": task_ids}
 
 
-@app.post("/api/embeddings")
-async def api_embeddings(request: Request, user=Depends(get_verified_user)):
-    return await openai.generate_embeddings(request=request, user=user)
-
-
 ##################################
 #
 # Config Endpoints
@@ -1544,6 +1540,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
+##################################
+# Embeddings
+##################################
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+      - Performs user/model checks and dispatches to the correct backend.
+      - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
 
 ############################
 # OAuth Login & Callback
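A quick client-side illustration of the new endpoint (not part of this commit; a minimal sketch assuming a local Open WebUI instance, a valid API key, and a placeholder embedding model id, with authentication via the usual Bearer token):

# Hypothetical usage sketch: POST an OpenAI-style payload to the new /api/embeddings route.
# The URL, API key, and model id below are placeholders, not values from this diff.
import requests

OPEN_WEBUI_URL = "http://localhost:8080"  # assumed default local deployment
API_KEY = "sk-..."                        # assumed Open WebUI API key

payload = {
    "model": "text-embedding-3-small",    # any embedding-capable model id known to the instance
    "input": ["first text to embed", "second text to embed"],
}

response = requests.post(
    f"{OPEN_WEBUI_URL}/api/embeddings",
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=60,
)
response.raise_for_status()
print(response.json())  # OpenAI-compatible response with a "data" list of embedding objects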
@@ -886,26 +886,36 @@ async def generate_chat_completion(
                 r.close()
             await session.close()
 
-@router.post("/embeddings")
-async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
+async def embeddings(request: Request, form_data: dict, user):
     """
-    Call embeddings endpoint
+    Calls the embeddings endpoint for OpenAI-compatible providers.
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): OpenAI-compatible embeddings payload.
+        user (UserModel): The authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
     """
-
-    body = await request.body()
-
     idx = 0
+    # Prepare payload/body
+    body = json.dumps(form_data)
+    # Find correct backend url/key based on model
+    await get_all_models(request, user=user)
+    model_id = form_data.get("model")
+    models = request.app.state.OPENAI_MODELS
+    if model_id in models:
+        idx = models[model_id]["urlIdx"]
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
 
     r = None
     session = None
     streaming = False
 
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
-            method=request.method,
+            method="POST",
             url=f"{url}/embeddings",
             data=body,
             headers={
@@ -918,14 +928,11 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
                         "X-OpenWebUI-User-Email": user.email,
                         "X-OpenWebUI-User-Role": user.role,
                     }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
                 ),
             },
         )
         r.raise_for_status()
 
         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
             streaming = True
             return StreamingResponse(
@@ -939,10 +946,8 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
         else:
             response_data = await r.json()
             return response_data
 
     except Exception as e:
         log.exception(e)
 
         detail = None
         if r is not None:
             try:
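For reference, a successful call proxied through this function passes the provider's OpenAI-compatible embeddings response straight back to the caller. An illustrative example of that shape (placeholder values, not output produced by this code):

# Illustrative only: standard OpenAI-style embeddings response passed through by embeddings().
example_response = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [0.0123, -0.0456, 0.0789],  # truncated placeholder vector
        }
    ],
    "model": "text-embedding-3-small",  # placeholder model id
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}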