open-webui/backend/open_webui/routers/tasks.py
Timothy Jaeryang Baek 87ba39df57
Some checks are pending
Deploy to HuggingFace Spaces / check-secret (push) Waiting to run
Deploy to HuggingFace Spaces / deploy (push) Blocked by required conditions
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / merge-main-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-cuda-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-ollama-images (push) Blocked by required conditions
Python CI / Format Backend (3.11) (push) Waiting to run
Frontend Build / Format & Build Frontend (push) Waiting to run
Frontend Build / Frontend Unit Tests (push) Waiting to run
Integration Test / Run Cypress Integration Tests (push) Waiting to run
Integration Test / Run Migration Tests (push) Waiting to run
fix: moa
2025-01-02 14:32:25 -08:00

516 lines
17 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
import logging
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS
# Module-level logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router on which every task endpoint defined below is registered.
router = APIRouter()
##################################
#
# Task Endpoints
#
##################################
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related settings.

    Args:
        request: Incoming request; settings are read from
            ``request.app.state.config``.
        user: Resolved by the ``get_verified_user`` dependency (not otherwise
            used in the body).

    Returns:
        A dict mapping each exposed setting name to its current value.
    """
    settings = request.app.state.config
    # Order matters only for the layout of the response; it mirrors the
    # order in which the settings are assigned by the update endpoint.
    exposed_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(settings, name) for name in exposed_names}
class TaskConfigForm(BaseModel):
    """Request body accepted by ``update_task_config`` (``POST /config/update``).

    Each field is copied onto the attribute of the same name on
    ``request.app.state.config``.
    """

    # Task model ids handed to get_task_model_id by the completion endpoints;
    # annotated Optional, so None is an accepted value.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # Title prompt; generate_title falls back to the default template when this is "".
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    # When false, generate_autocompletion rejects requests with a 400.
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Prompt length limit for autocompletion; only enforced when greater than 0.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # Tags prompt; generate_chat_tags falls back to the default template when this is "".
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    # When false, generate_chat_tags answers 200 with a "disabled" detail.
    ENABLE_TAGS_GENERATION: bool
    # Gate generate_queries for type "web_search" and "retrieval" respectively.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    # Query prompt; generate_queries falls back to the default when this is blank after strip().
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Stored and returned by the config endpoints; not read by the other endpoints visible in this file.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Overwrite the task settings with the values from the submitted form.

    Access is guarded by the ``get_admin_user`` dependency.

    Args:
        request: Incoming request; settings live on
            ``request.app.state.config``.
        form_data: New values, one per setting.
        user: Resolved by ``get_admin_user`` (not otherwise used in the body).

    Returns:
        A dict with the settings as read back from the config after the
        update, keyed by setting name.
    """
    settings = request.app.state.config
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )

    # Write every setting first, then read them all back so the response
    # reflects what the config object actually holds.
    for name in setting_names:
        setattr(settings, name, getattr(form_data, name))

    return {name: getattr(settings, name) for name in setting_names}
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for a title for the given conversation.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"`` and ``"messages"`` are read by
            subscript (a missing key raises ``KeyError``); ``"chat_id"`` is
            optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        Whatever ``generate_chat_completion`` returns, or a 400 JSON response
        with a generic detail if that call raises.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in ``MODELS``.
    """
    available_models = request.app.state.MODELS

    requested_model = form_data["model"]
    if requested_model not in available_models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The configured task models (if any) take part in choosing which model
    # actually runs the task; the decision is delegated to the helper.
    app_config = request.app.state.config
    task_model_id = get_task_model_id(
        requested_model,
        app_config.TASK_MODEL,
        app_config.TASK_MODEL_EXTERNAL,
        available_models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    # An empty configured template means "use the built-in default".
    configured_template = app_config.TITLE_GENERATION_PROMPT_TEMPLATE
    template = (
        DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
        if configured_template == ""
        else configured_template
    )

    messages = form_data["messages"]
    user_context = {
        "name": user.name,
        "location": user.info.get("location") if user.info else None,
    }
    content = title_generation_template(template, messages, user_context)

    # The output cap is sent as "max_tokens" for models owned by "ollama"
    # and as "max_completion_tokens" for every other model.
    is_ollama = available_models[task_model_id]["owned_by"] == "ollama"
    token_limit_key = "max_tokens" if is_ollama else "max_completion_tokens"

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        token_limit_key: 50,
        "metadata": {
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Full traceback goes to the log; the client only sees a generic message.
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for tags describing the given conversation.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"`` and ``"messages"`` are read by
            subscript (a missing key raises ``KeyError``); ``"chat_id"`` is
            optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        A 200 JSON response with a "disabled" detail when tags generation is
        turned off; otherwise whatever ``generate_chat_completion`` returns,
        or a 500 JSON response with a generic detail if that call raises.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in ``MODELS``.
    """
    app_config = request.app.state.config

    # Disabled feature is reported with a 200 and a message, not an error.
    if not app_config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    available_models = request.app.state.MODELS
    requested_model = form_data["model"]
    if requested_model not in available_models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Let the helper decide between the requested model and a configured
    # task model.
    task_model_id = get_task_model_id(
        requested_model,
        app_config.TASK_MODEL,
        app_config.TASK_MODEL_EXTERNAL,
        available_models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    # An empty configured template means "use the built-in default".
    configured_template = app_config.TAGS_GENERATION_PROMPT_TEMPLATE
    template = (
        DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
        if configured_template == ""
        else configured_template
    )

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    task_metadata = {
        "task": str(TASKS.TAGS_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to generate queries for the given conversation.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"`` and ``"messages"`` are read by
            subscript (a missing key raises ``KeyError``); ``"type"`` and
            ``"chat_id"`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        Whatever ``generate_chat_completion`` returns, or a 400 JSON response
        whose detail is the exception text if that call raises.

    Raises:
        HTTPException: 400 when ``type`` is ``"web_search"`` or
            ``"retrieval"`` and the matching feature toggle is off; 404 when
            ``form_data["model"]`` is not in ``MODELS``.
    """
    app_config = request.app.state.config

    # Named query_type so the builtin `type` is not shadowed.
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not app_config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not app_config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )
    # Any other (or missing) type is not gated by a toggle.

    models = request.app.state.MODELS
    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Let the helper decide between the requested model and a configured
    # task model.
    task_model_id = get_task_model_id(
        model_id,
        app_config.TASK_MODEL,
        app_config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template means "use the default".
    if (app_config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = app_config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure with its traceback; previously nothing was
        # logged here, unlike the title/tags/autocomplete handlers.
        log.exception("Error generating chat completion")
        # NOTE(review): the raw exception text is returned to the client,
        # whereas the title/tags/autocomplete handlers return a generic
        # message. Kept as-is so the response contract does not change.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for an autocompletion of the user's input.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"`` is read by subscript (a missing
            key raises ``KeyError``); ``"type"``, ``"prompt"``,
            ``"messages"`` and ``"chat_id"`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        Whatever ``generate_chat_completion`` returns, or a 500 JSON response
        with a generic detail if that call raises.

    Raises:
        HTTPException: 400 when autocompletion is disabled or the prompt is
            longer than the configured maximum; 404 when
            ``form_data["model"]`` is not in ``MODELS``.
    """
    app_config = request.app.state.config

    if not app_config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    # Named completion_type so the builtin `type` is not shadowed.
    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    # The limit is only enforced when it is positive. A missing prompt
    # (None) cannot exceed it, so skip len() instead of raising TypeError.
    max_length = app_config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    if max_length > 0 and prompt is not None and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    models = request.app.state.MODELS
    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Let the helper decide between the requested model and a configured
    # task model.
    task_model_id = get_task_model_id(
        model_id,
        app_config.TASK_MODEL,
        app_config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A blank (whitespace-only) configured template means "use the default".
    if (app_config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = app_config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model for an emoji matching the given prompt.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"`` and ``"prompt"`` are read by
            subscript (a missing key raises ``KeyError``); ``"chat_id"`` is
            optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        Whatever ``generate_chat_completion`` returns, or a 400 JSON response
        whose detail is the exception text if that call raises.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in ``MODELS``.
    """
    models = request.app.state.MODELS
    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Let the helper decide between the requested model and a configured
    # task model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # The output cap is sent as "max_tokens" for models owned by "ollama"
    # and as "max_completion_tokens" for every other model.
    if models[task_model_id]["owned_by"] == "ollama":
        token_limit = {"max_tokens": 4}
    else:
        token_limit = {"max_completion_tokens": 4}

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        # chat_id travels inside "metadata", as in the title, tags, queries,
        # autocomplete and MOA handlers, rather than as a top-level request
        # field.
        "metadata": {
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure with its traceback; previously nothing was
        # logged here, unlike the title/tags/autocomplete handlers.
        log.exception("Error generating chat completion")
        # NOTE(review): the raw exception text is returned to the client,
        # whereas the title/tags/autocomplete handlers return a generic
        # message. Kept as-is so the response contract does not change.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Ask the task model to produce a single answer from several responses.

    The prompt and the candidate responses are rendered into the default MOA
    template and sent to the task model as one user message.

    Args:
        request: Incoming request; ``request.app.state`` provides the config
            and the ``MODELS`` mapping.
        form_data: Request body. ``"model"``, ``"prompt"`` and
            ``"responses"`` are read by subscript (a missing key raises
            ``KeyError``); ``"stream"`` (default ``False``) and
            ``"chat_id"`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        Whatever ``generate_chat_completion`` returns, or a 400 JSON response
        whose detail is the exception text if that call raises.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in ``MODELS``.
    """
    models = request.app.state.MODELS
    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Let the helper decide between the requested model and a configured
    # task model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        # Unlike the other task endpoints, streaming is caller-selectable here.
        "stream": form_data.get("stream", False),
        "metadata": {
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure with its traceback; previously nothing was
        # logged here, unlike the title/tags/autocomplete handlers.
        log.exception("Error generating chat completion")
        # NOTE(review): the raw exception text is returned to the client,
        # whereas the title/tags/autocomplete handlers return a generic
        # message. Kept as-is so the response contract does not change.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )