open-webui/backend/open_webui/routers/tasks.py

516 lines
17 KiB
Python
Raw Permalink Normal View History

2024-12-10 08:00:01 +00:00
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
2024-12-12 03:52:46 +00:00
from fastapi.responses import JSONResponse, RedirectResponse
2024-12-10 08:00:01 +00:00
from pydantic import BaseModel
from typing import Optional
2024-12-10 08:54:13 +00:00
import logging
2024-12-13 04:22:17 +00:00
from open_webui.utils.chat import generate_chat_completion
2024-12-10 08:00:01 +00:00
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
2024-12-10 08:54:13 +00:00
from open_webui.constants import TASKS
2024-12-12 03:52:46 +00:00
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id
2024-12-10 08:54:13 +00:00
from open_webui.config import (
2024-12-13 06:28:42 +00:00
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
2024-12-10 08:54:13 +00:00
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
2024-12-13 06:28:42 +00:00
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
2024-12-10 08:54:13 +00:00
)
from open_webui.env import SRC_LOG_LEVELS
router = APIRouter()

# Module-level logger; its verbosity is taken from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


##################################
#
# Task Endpoints
#
##################################
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the task-related settings currently stored in the app config.

    Available to any verified user. The response maps each setting name to
    the value of the attribute with that name on ``request.app.state.config``.
    """
    app_config = request.app.state.config
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(app_config, name) for name in setting_names}
class TaskConfigForm(BaseModel):
    """Request body accepted by ``POST /config/update``.

    Each field is copied onto the attribute of the same name on
    ``request.app.state.config`` by ``update_task_config``. No field declares
    a default value; ``TASK_MODEL`` and ``TASK_MODEL_EXTERNAL`` also accept
    ``None``.
    """

    # Model ids handed to get_task_model_id to choose the model that runs tasks.
    # NOTE(review): presumably TASK_MODEL applies to local (Ollama) chat models
    # and TASK_MODEL_EXTERNAL to the others — confirm in get_task_model_id.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # Custom title prompt; an empty value makes /title/completions fall back
    # to DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE.
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    # Feature switch checked by /auto/completions.
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Maximum prompt length for /auto/completions; values <= 0 disable the check.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # Custom tags prompt; an empty value falls back to the default template.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    # Feature switch checked by /tags/completions.
    ENABLE_TAGS_GENERATION: bool
    # Feature switches checked by /queries/completions for the "web_search"
    # and "retrieval" request types respectively.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    # Custom query prompt; a blank value falls back to the default template.
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Not read by the endpoints visible in this router; presumably consumed
    # elsewhere (tool calling) — confirm.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Overwrite the task-related settings with the submitted values (admins only).

    Each form field is written onto the attribute of the same name on
    ``request.app.state.config``, always in the same order. The stored values
    are then read back and returned under the same keys as ``GET /config``.
    """
    app_config = request.app.state.config
    field_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )

    for field_name in field_names:
        setattr(app_config, field_name, getattr(form_data, field_name))

    return {field_name: getattr(app_config, field_name) for field_name in field_names}
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a chat title for ``form_data["messages"]`` with the task model.

    ``form_data`` must contain ``model`` and ``messages``; an optional
    ``chat_id`` is forwarded in the completion metadata.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged with its traceback and a 400 JSONResponse with a
    generic message is returned, so internal details are never exposed.
    """
    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    # Treat an unset, empty or whitespace-only custom template as "not set",
    # consistent with the query and autocomplete endpoints.
    custom_template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    if custom_template and custom_template.strip():
        template = custom_template
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Models owned by "ollama" get the limit as `max_tokens`; all others as
    # `max_completion_tokens`.
    token_limit_key = (
        "max_tokens"
        if models[task_model_id]["owned_by"] == "ollama"
        else "max_completion_tokens"
    )
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        token_limit_key: 50,
        "metadata": {
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
2024-12-10 08:00:01 +00:00
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate tags for the chat in ``form_data["messages"]`` with the task model.

    When ENABLE_TAGS_GENERATION is off, a 200 JSONResponse saying so is
    returned and no model is called. ``form_data`` must contain ``model`` and
    ``messages``; an optional ``chat_id`` is forwarded in the metadata.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged with its traceback and a 500 JSONResponse with a
    generic message is returned.
    """
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    # Treat an unset, empty or whitespace-only custom template as "not set",
    # consistent with the query and autocomplete endpoints.
    custom_template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    if custom_template and custom_template.strip():
        template = custom_template
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # log.exception keeps the traceback that a plain log.error would drop.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
2024-12-10 08:00:01 +00:00
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate search or retrieval queries for ``form_data["messages"]``.

    ``form_data["type"]`` selects the feature flag that must be enabled:
    ``"web_search"`` requires ENABLE_SEARCH_QUERY_GENERATION and
    ``"retrieval"`` requires ENABLE_RETRIEVAL_QUERY_GENERATION. Any other
    value skips the check.

    Raises:
        HTTPException: 400 when the selected query generation is disabled;
            404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged and a 400 JSONResponse with a generic message is
    returned. The exception text is no longer echoed to the client.
    """
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # Unset, empty or whitespace-only custom template -> use the default.
    custom_template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    if custom_template and custom_template.strip():
        template = custom_template
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep internal exception details in the server log only (they may
        # include upstream URLs or provider error bodies).
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
2024-12-10 08:00:01 +00:00
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate an autocompletion for ``form_data["prompt"]`` with the task model.

    ``form_data`` must contain ``model`` and ``prompt``. Optional keys are
    ``type`` and ``messages`` (passed to the template) and ``chat_id``
    (forwarded in the metadata).

    Raises:
        HTTPException: 400 when autocompletion is disabled, when ``prompt`` is
            missing, or when the prompt exceeds
            AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH (if that limit is > 0);
            404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged with its traceback and a 500 JSONResponse with a
    generic message is returned.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    # Without this check, a missing prompt reaches len() below and fails with
    # an unhandled TypeError (HTTP 500). Reject it as a client error instead.
    if prompt is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Prompt is required",
        )

    max_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # A non-positive limit disables the length check.
    if max_length > 0 and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # Unset, empty or whitespace-only custom template -> use the default.
    custom_template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    if custom_template and custom_template.strip():
        template = custom_template
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # log.exception keeps the traceback that a plain log.error would drop.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
2024-12-10 08:00:01 +00:00
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a single emoji for ``form_data["prompt"]`` with the task model.

    ``form_data`` must contain ``model`` and ``prompt``; an optional
    ``chat_id`` is forwarded in the completion metadata. The built-in
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE is always used.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged and a 400 JSONResponse with a generic message is
    returned. The exception text is no longer echoed to the client.
    """
    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Models owned by "ollama" get the limit as `max_tokens`; all others as
    # `max_completion_tokens`.
    token_limit_key = (
        "max_tokens"
        if models[task_model_id]["owned_by"] == "ollama"
        else "max_completion_tokens"
    )
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        token_limit_key: 4,
        # chat_id goes into metadata, as in the other task endpoints, rather
        # than as a top-level field of the completion request body.
        "metadata": {
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep internal exception details in the server log only.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
2024-12-10 08:00:01 +00:00
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Merge several model responses into one (mixture-of-agents) via the task model.

    ``form_data`` must contain ``model``, ``prompt`` and ``responses``.
    Optional keys are ``stream`` (default False) and ``chat_id`` (forwarded in
    the completion metadata). The built-in
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE is always used.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not a known model.

    Returns the result of ``generate_chat_completion``. If that call raises,
    the error is logged and a 400 JSONResponse with a generic message is
    returned. The exception text is no longer echoed to the client.
    """
    models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Swap in the admin-configured task model when one applies to this model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        # chat_id goes into metadata, as in the other task endpoints, rather
        # than as a top-level field of the completion request body.
        "metadata": {
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Keep internal exception details in the server log only.
        log.exception(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )