mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			700 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			700 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
 | 
						|
from fastapi.responses import JSONResponse, RedirectResponse
 | 
						|
 | 
						|
from pydantic import BaseModel
 | 
						|
from typing import Optional
 | 
						|
import logging
 | 
						|
import re
 | 
						|
 | 
						|
from open_webui.utils.chat import generate_chat_completion
 | 
						|
from open_webui.utils.task import (
 | 
						|
    title_generation_template,
 | 
						|
    query_generation_template,
 | 
						|
    image_prompt_generation_template,
 | 
						|
    autocomplete_generation_template,
 | 
						|
    tags_generation_template,
 | 
						|
    emoji_generation_template,
 | 
						|
    moa_response_generation_template,
 | 
						|
)
 | 
						|
from open_webui.utils.auth import get_admin_user, get_verified_user
 | 
						|
from open_webui.constants import TASKS
 | 
						|
 | 
						|
from open_webui.routers.pipelines import process_pipeline_inlet_filter
 | 
						|
from open_webui.utils.filter import (
 | 
						|
    get_sorted_filter_ids,
 | 
						|
    process_filter_functions,
 | 
						|
)
 | 
						|
from open_webui.utils.task import get_task_model_id
 | 
						|
 | 
						|
from open_webui.config import (
 | 
						|
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
 | 
						|
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
 | 
						|
)
 | 
						|
from open_webui.env import SRC_LOG_LEVELS
 | 
						|
 | 
						|
 | 
						|
# Module logger for the task router; its threshold is taken from the
# "MODELS" entry of SRC_LOG_LEVELS.
log: logging.Logger = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Every task endpoint defined below is registered on this router.
router = APIRouter()
 | 
						|
 | 
						|
 | 
						|
##################################
 | 
						|
#
 | 
						|
# Task Endpoints
 | 
						|
#
 | 
						|
##################################
 | 
						|
 | 
						|
 | 
						|
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related settings.

    Guarded by ``get_verified_user``. The response is a flat mapping whose
    keys are the attribute names read from ``request.app.state.config``.
    """
    cfg = request.app.state.config
    # Order matters only for the JSON key order of the response.
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_TITLE_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(cfg, name) for name in setting_names}
 | 
						|
 | 
						|
 | 
						|
# Request body for POST /config/update. Field names match the attributes on
# request.app.state.config that the update endpoint writes; the declaration
# order is kept as-is because it defines the model's schema field order.
class TaskConfigForm(BaseModel):
    # Task-model selection; both accept null.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # Title generation.
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    # Image-prompt generation (no enable flag in this form).
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    # Autocompletion; the max-length value is compared against the prompt
    # length only when it is > 0 (see /auto/completions).
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # Tag generation.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    # Search / retrieval query generation.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Prompt template for tool function calling.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
 | 
						|
 | 
						|
 | 
						|
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Overwrite the task-related settings and echo them back.

    Guarded by ``get_admin_user``. Each field of ``form_data`` is copied onto
    ``request.app.state.config`` under the same attribute name, in the order
    listed below; the response is then read back from the config object
    (not from ``form_data``).
    """
    cfg = request.app.state.config
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "ENABLE_TITLE_GENERATION",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )

    for name in setting_names:
        setattr(cfg, name, getattr(form_data, name))

    return {name: getattr(cfg, name) for name in setting_names}
 | 
						|
 | 
						|
 | 
						|
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to generate a title for a chat.

    ``form_data`` must contain ``model`` (the chat's model id) and
    ``messages``; ``chat_id`` is optional and is forwarded in the metadata.

    Returns the response of ``generate_chat_completion``; a 200 JSON notice
    when title generation is disabled; or a 400 JSON error with a generic
    message when the completion call fails (the traceback is logged).

    Raises:
        HTTPException: 404 if ``model`` is not among the available models.
    """
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    # A request flagged "direct" that carries a model on request.state is
    # resolved only against that model; otherwise the app-wide MODELS mapping
    # is used.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Pick the model that will actually run the task, given the configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL settings.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    messages = form_data["messages"]

    # Remove <details type="reasoning"> blocks from the messages before they
    # are rendered into the title prompt. Only string contents are rewritten:
    # re.sub raises TypeError on non-string content (e.g. a list of
    # multimodal parts), and a message without "content" is left alone
    # instead of raising KeyError. Messages are edited in place, so the
    # stripped text is also what appears in metadata["task_body"].
    reasoning_block = re.compile(
        r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>", flags=re.S
    )
    for message in messages:
        message_content = message.get("content")
        if isinstance(message_content, str):
            message["content"] = reasoning_block.sub("", message_content).strip()

    content = title_generation_template(
        template,
        messages,
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama-owned models get "max_tokens"; all others get
        # "max_completion_tokens".
        **(
            {"max_tokens": 1000}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 1000,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run the payload through the pipeline inlet filters; any exception
    # propagates to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to propose tags for a chat.

    ``form_data`` must contain ``model`` and ``messages``; ``chat_id`` is
    optional. Returns the completion response, a 200 JSON notice when tag
    generation is disabled, or a 500 JSON error with a generic message when
    the completion call fails.

    Raises:
        HTTPException: 404 if ``model`` is not among the available models.
    """
    config = request.app.state.config

    if not config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    use_request_model = getattr(request.state, "direct", False) and hasattr(
        request.state, "model"
    )
    models = (
        {request.state.model["id"]: request.state.model}
        if use_request_model
        else request.app.state.MODELS
    )

    requested_model = form_data["model"]
    if requested_model not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        requested_model,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    # An empty configured template falls back to the built-in default.
    template = config.TAGS_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    prompt_text = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    task_metadata = {
        **(request.state.metadata if hasattr(request.state, "metadata") else {}),
        "task": str(TASKS.TAGS_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": prompt_text}],
        "stream": False,
        "metadata": task_metadata,
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to turn the conversation into an image prompt.

    ``form_data`` must contain ``model`` and ``messages``; ``chat_id`` is
    optional. No enable flag is consulted here. Returns the completion
    response, or a 400 JSON error with a generic message when the completion
    call fails (the traceback is logged).

    Raises:
        HTTPException: 404 if ``model`` is not among the available models.
    """
    state = request.state
    config = request.app.state.config

    if getattr(state, "direct", False) and hasattr(state, "model"):
        models = {state.model["id"]: state.model}
    else:
        models = request.app.state.MODELS

    requested_model = form_data["model"]
    if requested_model not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    task_model_id = get_task_model_id(
        requested_model,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    template = config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    prompt_text = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={"name": user.name},
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": prompt_text}],
        "stream": False,
        "metadata": {
            **(state.metadata if hasattr(state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to produce search or retrieval queries for a chat.

    ``form_data`` must contain ``model`` and ``messages``. ``type`` selects
    which feature flag is checked: "web_search" checks
    ENABLE_SEARCH_QUERY_GENERATION and "retrieval" checks
    ENABLE_RETRIEVAL_QUERY_GENERATION. Any other value, including a missing
    one, is not gated. ``chat_id`` is optional.

    Returns the completion response, or a 400 JSON error with a generic
    message when the completion call fails (the traceback is logged).

    Raises:
        HTTPException: 400 if the relevant query-generation feature is
            disabled; 404 if ``model`` is not among the available models.
    """
    # Named query_type so the builtin ``type`` is not shadowed.
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Pick the model that will actually run the task, given the configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL settings.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template falls back to the default.
    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Keep the details server-side: echoing str(e) back may expose
        # internal information to the client. This matches the title, tags
        # and image-prompt handlers in this router.
        log.exception("Error generating queries")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to autocomplete ``form_data["prompt"]``.

    ``form_data`` must contain ``model``. ``prompt``, ``messages``, ``type``
    and ``chat_id`` are read with ``.get`` and passed through to the template
    helper and metadata.

    Returns the completion response, or a 500 JSON error with a generic
    message when the completion call fails (the traceback is logged).

    Raises:
        HTTPException: 400 if autocompletion is disabled or the prompt is
            longer than AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (when that
            limit is > 0); 404 if ``model`` is not among the available models.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    # Named completion_type so the builtin ``type`` is not shadowed.
    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    max_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # A missing prompt counts as empty for the length check; previously
    # len(None) raised TypeError and surfaced as an unhandled 500.
    if max_length > 0 and len(prompt or "") > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Pick the model that will actually run the task, given the configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL settings.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template falls back to the default.
    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # log.exception keeps the traceback that the previous log.error dropped.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for an emoji matching ``form_data["prompt"]``.

    ``form_data`` must contain ``model`` and ``prompt``; ``chat_id`` is
    optional. Output is capped at 4 tokens, and the built-in emoji template
    is always used.

    Returns the completion response, or a 400 JSON error with a generic
    message when the completion call fails (the traceback is logged).

    Raises:
        HTTPException: 404 if ``model`` is not among the available models.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Pick the model that will actually run the task, given the configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL settings.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama-owned models get "max_tokens"; all others get
        # "max_completion_tokens".
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        # chat_id lives in metadata, as in every other task endpoint in this
        # router; at the top level it sat alongside the completion parameters
        # and may have been forwarded to the model provider.
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Keep the details server-side: echoing str(e) back may expose
        # internal information to the client.
        log.exception("Error generating emoji")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
 | 
						|
 | 
						|
 | 
						|
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to merge several candidate responses (MoA).

    ``form_data`` must contain ``model``, ``prompt`` and ``responses``;
    ``stream`` (default False) and ``chat_id`` are optional. The built-in
    MoA template is always used.

    Returns the completion response, or a 400 JSON error with a generic
    message when the completion call raises (the traceback is logged).

    Raises:
        HTTPException: 404 if ``model`` is not among the available models.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]

    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Pick the model that will actually run the task, given the configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL settings.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Inlet-filter errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Keep the details server-side: echoing str(e) back may expose
        # internal information to the client.
        log.exception("Error generating MOA response")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
 |