open-webui/backend/open_webui/routers/tasks.py
Timothy Jaeryang Baek d3d161f723 wip
2024-12-10 00:54:13 -08:00

612 lines
21 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from pydantic import BaseModel
from starlette.responses import FileResponse
from typing import Optional
import logging
from open_webui.utils.task import (
title_generation_template,
query_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.config import (
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
##################################
#
# Task Endpoints
#
##################################
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
return {
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
ENABLE_TAGS_GENERATION: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
QUERY_GENERATION_PROMPT_TEMPLATE: str
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@router.post("/config/update")
async def update_task_config(
request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
form_data.ENABLE_AUTOCOMPLETE_GENERATION
)
request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
)
request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
form_data.ENABLE_SEARCH_QUERY_GENERATION
)
request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
)
request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@router.post("/title/completions")
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating chat title using model {task_model_id} for user {user.email} "
)
if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
else:
template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
content = title_generation_template(
template,
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 50}
if models[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 50,
}
),
"metadata": {
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@router.post("/tags/completions")
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_TAGS_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Tags generation is disabled"},
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating chat tags using model {task_model_id} for user {user.email} "
)
if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
else:
template = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
content = tags_generation_template(
template, form_data["messages"], {"name": user.name}
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@router.post("/queries/completions")
async def generate_queries(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
type = form_data.get("type")
if type == "web_search":
if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Search query generation is disabled",
)
elif type == "retrieval":
if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Query generation is disabled",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating {type} queries using model {task_model_id} for user {user.email}"
)
if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
content = query_generation_template(
template, form_data["messages"], {"name": user.name}
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@router.post("/auto/completions")
async def generate_autocompletion(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Autocompletion generation is disabled",
)
type = form_data.get("type")
prompt = form_data.get("prompt")
messages = form_data.get("messages")
if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
if (
len(prompt)
> request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating autocompletion using model {task_model_id} for user {user.email}"
)
if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
content = autocomplete_generation_template(
template, prompt, messages, type, {"name": user.name}
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@router.post("/emoji/completions")
async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: """{{prompt}}"""
'''
content = emoji_generation_template(
template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama"
else {
"max_completion_tokens": 4,
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@router.post("/moa/completions")
async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models: {{responses}}"""
content = moa_response_generation_template(
template,
form_data["prompt"],
form_data["responses"],
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"chat_id": form_data.get("chat_id", None),
"metadata": {
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,
},
}
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)