mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 14:52:23 +00:00
general refac
This commit is contained in:
parent
1349c6049e
commit
f6bec8d9f3
@ -130,12 +130,6 @@ from open_webui.utils.response import (
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.task import (
|
||||
rag_template,
|
||||
title_generation_template,
|
||||
query_generation_template,
|
||||
autocomplete_generation_template,
|
||||
tags_generation_template,
|
||||
emoji_generation_template,
|
||||
moa_response_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
@ -1263,12 +1257,15 @@ async def get_base_models(user=Depends(get_admin_user)):
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(
|
||||
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
model_list = await get_all_models()
|
||||
model_list = request.state.models
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
@ -1665,574 +1662,6 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
|
||||
return data
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Task Endpoints
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
# TODO: Refactor task API endpoints below into a separate file
|
||||
|
||||
|
||||
@app.get("/api/task/config")
|
||||
async def get_task_config(user=Depends(get_verified_user)):
|
||||
return {
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_TAGS_GENERATION: bool
|
||||
ENABLE_SEARCH_QUERY_GENERATION: bool
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
||||
QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
||||
|
||||
|
||||
@app.post("/api/task/config/update")
|
||||
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
|
||||
form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
)
|
||||
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||
form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||
)
|
||||
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
|
||||
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
|
||||
form_data.ENABLE_SEARCH_QUERY_GENERATION
|
||||
)
|
||||
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
|
||||
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
)
|
||||
|
||||
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
return {
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/task/title/completions")
|
||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat title using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
|
||||
|
||||
Examples of titles:
|
||||
📉 Stock Market Trends
|
||||
🍪 Perfect Chocolate Chip Recipe
|
||||
Evolution of Music Streaming
|
||||
Remote Work Productivity Tips
|
||||
Artificial Intelligence in Healthcare
|
||||
🎮 Video Game Development Insights
|
||||
|
||||
<chat_history>
|
||||
{{MESSAGES:END:2}}
|
||||
</chat_history>"""
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 50}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 50,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
"task": str(TASKS.TITLE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/tags/completions")
|
||||
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
if not app.state.config.ENABLE_TAGS_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Tags generation is disabled"},
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat tags using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """### Task:
|
||||
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
|
||||
|
||||
### Guidelines:
|
||||
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
|
||||
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
|
||||
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
|
||||
- Use the chat's primary language; default to English if multilingual
|
||||
- Prioritize accuracy over specificity
|
||||
|
||||
### Output:
|
||||
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
content = tags_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.TAGS_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/queries/completions")
|
||||
async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
type = form_data.get("type")
|
||||
if type == "web_search":
|
||||
if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Search query generation is disabled",
|
||||
)
|
||||
elif type == "retrieval":
|
||||
if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Query generation is disabled",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating {type} queries using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
|
||||
if (app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = query_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.QUERY_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/auto/completions")
|
||||
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
|
||||
if not app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Autocompletion generation is disabled",
|
||||
)
|
||||
|
||||
type = form_data.get("type")
|
||||
prompt = form_data.get("prompt")
|
||||
messages = form_data.get("messages")
|
||||
|
||||
if app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
|
||||
if len(prompt) > app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Input prompt exceeds maximum length of {app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating autocompletion using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
|
||||
if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = autocomplete_generation_template(
|
||||
template, prompt, messages, type, {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/emoji/completions")
|
||||
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
|
||||
|
||||
template = '''
|
||||
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
||||
|
||||
Message: """{{prompt}}"""
|
||||
'''
|
||||
content = emoji_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/moa/completions")
|
||||
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
|
||||
|
||||
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
||||
|
||||
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
|
||||
Responses from models: {{responses}}"""
|
||||
|
||||
content = moa_response_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
form_data["responses"],
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Pipelines Endpoints
|
||||
|
0
backend/open_webui/routers/chat.py
Normal file
0
backend/open_webui/routers/chat.py
Normal file
99
backend/open_webui/routers/pipelines.py
Normal file
99
backend/open_webui/routers/pipelines.py
Normal file
@ -0,0 +1,99 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
|
||||
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
from open_webui.utils.misc import get_gravatar_url
|
||||
from open_webui.utils.pdf_generator import PDFGenerator
|
||||
from open_webui.utils.auth import get_admin_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
):
|
||||
return get_gravatar_url(email)
|
||||
|
||||
|
||||
class CodeFormatRequest(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
@router.post("/code/format")
|
||||
async def format_code(request: CodeFormatRequest):
|
||||
try:
|
||||
formatted_code = black.format_str(request.code, mode=black.Mode())
|
||||
return {"code": formatted_code}
|
||||
except black.NothingChanged:
|
||||
return {"code": request.code}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
class MarkdownForm(BaseModel):
|
||||
md: str
|
||||
|
||||
|
||||
@router.post("/markdown")
|
||||
async def get_html_from_markdown(
|
||||
form_data: MarkdownForm,
|
||||
):
|
||||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
|
||||
class ChatForm(BaseModel):
|
||||
title: str
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
|
||||
return Response(
|
||||
content=pdf_bytes,
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/db/download")
|
||||
async def download_db(user=Depends(get_admin_user)):
|
||||
if not ENABLE_ADMIN_EXPORT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
from open_webui.apps.webui.internal.db import engine
|
||||
|
||||
if engine.name != "sqlite":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
|
||||
)
|
||||
return FileResponse(
|
||||
engine.url.database,
|
||||
media_type="application/octet-stream",
|
||||
filename="webui.db",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/litellm/config")
|
||||
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
|
||||
return FileResponse(
|
||||
f"{DATA_DIR}/litellm/config.yaml",
|
||||
media_type="application/octet-stream",
|
||||
filename="config.yaml",
|
||||
)
|
596
backend/open_webui/routers/tasks.py
Normal file
596
backend/open_webui/routers/tasks.py
Normal file
@ -0,0 +1,596 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.utils.task import (
|
||||
title_generation_template,
|
||||
query_generation_template,
|
||||
autocomplete_generation_template,
|
||||
tags_generation_template,
|
||||
emoji_generation_template,
|
||||
moa_response_generation_template,
|
||||
)
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Task Endpoints
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_TAGS_GENERATION: bool
|
||||
ENABLE_SEARCH_QUERY_GENERATION: bool
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
||||
QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_task_config(
|
||||
request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
|
||||
form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
)
|
||||
request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||
form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||
)
|
||||
|
||||
request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
|
||||
request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
|
||||
form_data.ENABLE_SEARCH_QUERY_GENERATION
|
||||
)
|
||||
request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
|
||||
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
)
|
||||
|
||||
request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/title/completions")
|
||||
async def generate_title(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat title using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
|
||||
|
||||
Examples of titles:
|
||||
📉 Stock Market Trends
|
||||
🍪 Perfect Chocolate Chip Recipe
|
||||
Evolution of Music Streaming
|
||||
Remote Work Productivity Tips
|
||||
Artificial Intelligence in Healthcare
|
||||
🎮 Video Game Development Insights
|
||||
|
||||
<chat_history>
|
||||
{{MESSAGES:END:2}}
|
||||
</chat_history>"""
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 50}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 50,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
"task": str(TASKS.TITLE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@router.post("/tags/completions")
|
||||
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
if not request.app.state.config.ENABLE_TAGS_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Tags generation is disabled"},
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat tags using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """### Task:
|
||||
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
|
||||
|
||||
### Guidelines:
|
||||
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
|
||||
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
|
||||
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
|
||||
- Use the chat's primary language; default to English if multilingual
|
||||
- Prioritize accuracy over specificity
|
||||
|
||||
### Output:
|
||||
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
content = tags_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.TAGS_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@router.post("/queries/completions")
|
||||
async def generate_queries(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
type = form_data.get("type")
|
||||
if type == "web_search":
|
||||
if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Search query generation is disabled",
|
||||
)
|
||||
elif type == "retrieval":
|
||||
if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Query generation is disabled",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating {type} queries using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
|
||||
if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = query_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.QUERY_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@router.post("/auto/completions")
|
||||
async def generate_autocompletion(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Autocompletion generation is disabled",
|
||||
)
|
||||
|
||||
type = form_data.get("type")
|
||||
prompt = form_data.get("prompt")
|
||||
messages = form_data.get("messages")
|
||||
|
||||
if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
|
||||
if (
|
||||
len(prompt)
|
||||
> request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating autocompletion using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
|
||||
if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = autocomplete_generation_template(
|
||||
template, prompt, messages, type, {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@router.post("/emoji/completions")
|
||||
async def generate_emoji(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
|
||||
|
||||
template = '''
|
||||
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
||||
|
||||
Message: """{{prompt}}"""
|
||||
'''
|
||||
content = emoji_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
|
||||
}
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@router.post("/moa/completions")
|
||||
async def generate_moa_response(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
|
||||
|
||||
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
||||
|
||||
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
|
||||
Responses from models: {{responses}}"""
|
||||
|
||||
content = moa_response_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
form_data["responses"],
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
Binary file not shown.
Before Width: | Height: | Size: 6.0 KiB |
Loading…
Reference in New Issue
Block a user