feat: follow ups

This commit is contained in:
Timothy Jaeryang Baek 2025-06-03 18:07:29 +04:00
parent f8b941fb96
commit 9e49fbc8bf
6 changed files with 176 additions and 0 deletions

View File

@ -1411,6 +1411,34 @@ Strictly return in JSON format:
{{MESSAGES:END:6}}
</chat_history>"""
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
"task.follow_up.prompt_template",
os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
Suggest 3-5 relevant follow-up questions or discussion prompts based on the chat history to help continue or deepen the conversation.
### Guidelines:
- Make questions concise, clear, and directly related to the discussed topic(s).
- Only generate follow-ups that make sense given the chat content and do not repeat what was already covered.
- If the conversation is very short or not specific, suggest more general follow-ups.
- Use the chat's primary language; default to English if multilingual.
- Response must be a JSON array of strings, no extra text or formatting.
### Output:
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
"ENABLE_FOLLOW_UP_GENERATION",
"task.follow_up.enable",
os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
)
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",

View File

@ -111,6 +111,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"

View File

@ -359,10 +359,12 @@ from open_webui.config import (
TASK_MODEL_EXTERNAL,
ENABLE_TAGS_GENERATION,
ENABLE_TITLE_GENERATION,
ENABLE_FOLLOW_UP_GENERATION,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@ -959,6 +961,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@ -966,6 +969,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

View File

@ -9,6 +9,7 @@ import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
follow_up_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
ENABLE_FOLLOW_UP_GENERATION: bool
ENABLE_TAGS_GENERATION: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@ -94,6 +100,13 @@ async def update_task_config(
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
form_data.ENABLE_FOLLOW_UP_GENERATION
)
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
@ -133,6 +146,8 @@ async def update_task_config(
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@ -231,6 +246,86 @@ async def generate_title(
)
@router.post("/follow_up/completions")
async def generate_follow_ups(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Follow-up generation is disabled"},
)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating chat title using model {task_model_id} for user {user.email} "
)
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
content = follow_up_generation_template(
template,
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.FOLLOW_UP_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "An internal error has occurred."},
)
@router.post("/tags/completions")
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)

View File

@ -207,6 +207,24 @@ def title_generation_template(
return template
def follow_up_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:

View File

@ -31,6 +31,8 @@
TASK_MODEL_EXTERNAL: '',
ENABLE_TITLE_GENERATION: true,
TITLE_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_FOLLOW_UP_GENERATION: true,
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: '',
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_AUTOCOMPLETE_GENERATION: true,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
@ -235,6 +237,32 @@
</div>
{/if}
<div class="mb-2.5 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Follow Up Generation')}
</div>
<Switch bind:state={taskConfig.ENABLE_FOLLOW_UP_GENERATION} />
</div>
{#if taskConfig.ENABLE_FOLLOW_UP_GENERATION}
<div class="mb-2.5">
<div class=" mb-1 text-xs font-medium">{$i18n.t('Follow Up Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
/>
</Tooltip>
</div>
{/if}
<div class="mb-2.5 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Tags Generation')}