from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse

from pydantic import BaseModel
from typing import Optional
import logging

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    title_generation_template,
    query_generation_template,
    autocomplete_generation_template,
    tags_generation_template,
    emoji_generation_template,
    moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS

from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id

from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()


##################################
#
# Task Endpoints
#
##################################


# Settings exposed by GET /config and accepted by POST /config/update, in the
# order they appear in the JSON response. Must stay in sync with TaskConfigForm.
# (AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE is read by generate_autocompletion
# but is intentionally not part of this editable set.)
_TASK_CONFIG_KEYS = (
    "TASK_MODEL",
    "TASK_MODEL_EXTERNAL",
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "ENABLE_AUTOCOMPLETE_GENERATION",
    "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
    "TAGS_GENERATION_PROMPT_TEMPLATE",
    "ENABLE_TAGS_GENERATION",
    "ENABLE_SEARCH_QUERY_GENERATION",
    "ENABLE_RETRIEVAL_QUERY_GENERATION",
    "QUERY_GENERATION_PROMPT_TEMPLATE",
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
)


def _task_config_snapshot(config) -> dict:
    """Return a dict of every key in _TASK_CONFIG_KEYS read from ``config``."""
    return {key: getattr(config, key) for key in _TASK_CONFIG_KEYS}


def _select_template(configured: Optional[str], default: str) -> str:
    """Return ``configured`` unless it is None, empty or whitespace-only.

    In those cases the built-in ``default`` template is returned instead, so a
    blank admin setting never produces an empty prompt.
    """
    if configured and configured.strip():
        return configured
    return default


def _resolve_task_model_id(request: Request, form_data: dict) -> str:
    """Validate ``form_data["model"]`` and return the model id to run the task on.

    Raises:
        HTTPException(404): if the model key is missing or unknown.
    """
    models = request.app.state.MODELS

    # .get() so a missing "model" key yields the intended 404, not a KeyError/500.
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # NOTE(review): get_task_model_id presumably swaps in the admin-configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL when one is set — confirm in utils.task.
    return get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )


def _task_error_response() -> JSONResponse:
    """Log the in-flight exception and return a generic HTTP 400 response.

    Must be called from inside an ``except`` block so ``exc_info=True`` picks up
    the active exception. The exception text is deliberately not sent to the
    client to avoid leaking internal details.
    """
    log.error("Exception occurred", exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"detail": "An internal error has occurred."},
    )


@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-generation settings (any verified user)."""
    return _task_config_snapshot(request.app.state.config)


class TaskConfigForm(BaseModel):
    # Request body for POST /config/update; fields mirror _TASK_CONFIG_KEYS.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Overwrite every task-generation setting (admin only) and echo the result."""
    config = request.app.state.config
    for key in _TASK_CONFIG_KEYS:
        setattr(config, key, getattr(form_data, key))

    return _task_config_snapshot(config)


@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for a short chat title based on ``form_data["messages"]``.

    Raises:
        HTTPException(404): if the requested model is unknown.
    Returns:
        The result of generate_chat_completion, or a generic 400 JSONResponse
        if that call raises.
    """
    task_model_id = _resolve_task_model_id(request, form_data)
    models = request.app.state.MODELS

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    template = _select_template(
        request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    )

    content = title_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama-owned models take "max_tokens"; all others "max_completion_tokens".
        **(
            {"max_tokens": 50}
            if models[task_model_id]["owned_by"] == "ollama"
            else {"max_completion_tokens": 50}
        ),
        "metadata": {
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        return _task_error_response()


@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for tags describing ``form_data["messages"]``.

    Returns HTTP 200 with an explanatory detail when tag generation is disabled.
    """
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    task_model_id = _resolve_task_model_id(request, form_data)

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    template = _select_template(
        request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    )

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        return _task_error_response()


@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for search/retrieval queries for ``form_data["messages"]``.

    ``form_data["type"]`` of "web_search" or "retrieval" is gated by the matching
    ENABLE_*_QUERY_GENERATION flag; any other value is not gated.

    Raises:
        HTTPException(400): if the requested query type is disabled.
        HTTPException(404): if the requested model is unknown.
    """
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    task_model_id = _resolve_task_model_id(request, form_data)

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    template = _select_template(
        request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    )

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        return _task_error_response()


@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to autocomplete ``form_data["prompt"]``.

    Raises:
        HTTPException(400): if autocompletion is disabled, or the prompt is longer
            than AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (when that limit is > 0).
        HTTPException(404): if the requested model is unknown.
    """
    config = request.app.state.config

    if not config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    max_length = config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # `prompt` may be None when absent; guard so len() cannot raise TypeError.
    if max_length > 0 and prompt and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    task_model_id = _resolve_task_model_id(request, form_data)

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    template = _select_template(
        config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
        DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    )

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        return _task_error_response()
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for an emoji reflecting ``form_data["prompt"]``.

    Raises:
        HTTPException(404): if the requested model is missing or unknown.
    Returns:
        The result of generate_chat_completion, or a generic 400 JSONResponse
        if that call raises.
    """
    models = request.app.state.MODELS

    # .get() so a missing "model" key yields the intended 404, not a KeyError/500.
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # NOTE(review): get_task_model_id presumably swaps in the admin-configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL when one is set — confirm in utils.task.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama-owned models take "max_tokens"; all others "max_completion_tokens".
        **(
            {"max_tokens": 4}
            if models[task_model_id]["owned_by"] == "ollama"
            else {"max_completion_tokens": 4}
        ),
        # chat_id goes in metadata like every other task endpoint in this router,
        # not as a top-level request field. NOTE(review): a top-level key may be
        # passed through to the backend API — confirm in generate_chat_completion.
        "metadata": {
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Log server-side; do not echo exception text to the client.
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Combine ``form_data["responses"]`` to ``form_data["prompt"]`` into one answer.

    Uses the mixture-of-agents (MOA) prompt template. Streaming follows
    ``form_data["stream"]`` (default False).

    Raises:
        HTTPException(404): if the requested model is missing or unknown.
    Returns:
        The result of generate_chat_completion, or a generic 400 JSONResponse
        if that call raises.
    """
    models = request.app.state.MODELS

    # .get() so a missing "model" key yields the intended 404, not a KeyError/500.
    model_id = form_data.get("model")
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # NOTE(review): get_task_model_id presumably swaps in the admin-configured
    # TASK_MODEL / TASK_MODEL_EXTERNAL when one is set — confirm in utils.task.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        # chat_id goes in metadata like every other task endpoint in this router.
        "metadata": {
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Log server-side; do not echo exception text to the client.
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )