mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: follow ups
This commit is contained in:
		
							parent
							
								
									f8b941fb96
								
							
						
					
					
						commit
						9e49fbc8bf
					
				@ -1411,6 +1411,34 @@ Strictly return in JSON format:
 | 
			
		||||
{{MESSAGES:END:6}}
 | 
			
		||||
</chat_history>"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 | 
			
		||||
    "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
 | 
			
		||||
    "task.follow_up.prompt_template",
 | 
			
		||||
    os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
 | 
			
		||||
Suggest 3-5 relevant follow-up questions or discussion prompts based on the chat history to help continue or deepen the conversation.
 | 
			
		||||
### Guidelines:
 | 
			
		||||
- Make questions concise, clear, and directly related to the discussed topic(s).
 | 
			
		||||
- Only generate follow-ups that make sense given the chat content and do not repeat what was already covered.
 | 
			
		||||
- If the conversation is very short or not specific, suggest more general follow-ups.
 | 
			
		||||
- Use the chat's primary language; default to English if multilingual.
 | 
			
		||||
- Response must be a JSON array of strings, no extra text or formatting.
 | 
			
		||||
### Output:
 | 
			
		||||
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
 | 
			
		||||
### Chat History:
 | 
			
		||||
<chat_history>
 | 
			
		||||
{{MESSAGES:END:6}}
 | 
			
		||||
</chat_history>"""
 | 
			
		||||
 | 
			
		||||
ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
 | 
			
		||||
    "ENABLE_FOLLOW_UP_GENERATION",
 | 
			
		||||
    "task.follow_up.enable",
 | 
			
		||||
    os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ENABLE_TAGS_GENERATION = PersistentConfig(
 | 
			
		||||
    "ENABLE_TAGS_GENERATION",
 | 
			
		||||
    "task.tags.enable",
 | 
			
		||||
 | 
			
		||||
@ -111,6 +111,7 @@ class TASKS(str, Enum):
 | 
			
		||||
 | 
			
		||||
    DEFAULT = lambda task="": f"{task if task else 'generation'}"
 | 
			
		||||
    TITLE_GENERATION = "title_generation"
 | 
			
		||||
    FOLLOW_UP_GENERATION = "follow_up_generation"
 | 
			
		||||
    TAGS_GENERATION = "tags_generation"
 | 
			
		||||
    EMOJI_GENERATION = "emoji_generation"
 | 
			
		||||
    QUERY_GENERATION = "query_generation"
 | 
			
		||||
 | 
			
		||||
@ -359,10 +359,12 @@ from open_webui.config import (
 | 
			
		||||
    TASK_MODEL_EXTERNAL,
 | 
			
		||||
    ENABLE_TAGS_GENERATION,
 | 
			
		||||
    ENABLE_TITLE_GENERATION,
 | 
			
		||||
    ENABLE_FOLLOW_UP_GENERATION,
 | 
			
		||||
    ENABLE_SEARCH_QUERY_GENERATION,
 | 
			
		||||
    ENABLE_RETRIEVAL_QUERY_GENERATION,
 | 
			
		||||
    ENABLE_AUTOCOMPLETE_GENERATION,
 | 
			
		||||
    TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    TAGS_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
 | 
			
		||||
@ -959,6 +961,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
 | 
			
		||||
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
 | 
			
		||||
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 | 
			
		||||
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
 | 
			
		||||
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
@ -966,6 +969,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
 | 
			
		||||
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
)
 | 
			
		||||
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
 | 
			
		||||
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ import re
 | 
			
		||||
from open_webui.utils.chat import generate_chat_completion
 | 
			
		||||
from open_webui.utils.task import (
 | 
			
		||||
    title_generation_template,
 | 
			
		||||
    follow_up_generation_template,
 | 
			
		||||
    query_generation_template,
 | 
			
		||||
    image_prompt_generation_template,
 | 
			
		||||
    autocomplete_generation_template,
 | 
			
		||||
@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
 | 
			
		||||
 | 
			
		||||
from open_webui.config import (
 | 
			
		||||
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
 | 
			
		||||
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
 | 
			
		||||
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
 | 
			
		||||
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
 | 
			
		||||
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
 | 
			
		||||
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
 | 
			
		||||
@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
 | 
			
		||||
    ENABLE_AUTOCOMPLETE_GENERATION: bool
 | 
			
		||||
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
 | 
			
		||||
    TAGS_GENERATION_PROMPT_TEMPLATE: str
 | 
			
		||||
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
 | 
			
		||||
    ENABLE_FOLLOW_UP_GENERATION: bool
 | 
			
		||||
    ENABLE_TAGS_GENERATION: bool
 | 
			
		||||
    ENABLE_SEARCH_QUERY_GENERATION: bool
 | 
			
		||||
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
 | 
			
		||||
@ -94,6 +100,13 @@ async def update_task_config(
 | 
			
		||||
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
 | 
			
		||||
        form_data.ENABLE_FOLLOW_UP_GENERATION
 | 
			
		||||
    )
 | 
			
		||||
    request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
        form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    )
 | 
			
		||||
@ -133,6 +146,8 @@ async def update_task_config(
 | 
			
		||||
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
 | 
			
		||||
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
 | 
			
		||||
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
 | 
			
		||||
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
 | 
			
		||||
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
 | 
			
		||||
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
@ -231,6 +246,86 @@ async def generate_title(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/follow_up/completions")
 | 
			
		||||
async def generate_follow_ups(
 | 
			
		||||
    request: Request, form_data: dict, user=Depends(get_verified_user)
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
 | 
			
		||||
        return JSONResponse(
 | 
			
		||||
            status_code=status.HTTP_200_OK,
 | 
			
		||||
            content={"detail": "Follow-up generation is disabled"},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
 | 
			
		||||
        models = {
 | 
			
		||||
            request.state.model["id"]: request.state.model,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
        models = request.app.state.MODELS
 | 
			
		||||
 | 
			
		||||
    model_id = form_data["model"]
 | 
			
		||||
    if model_id not in models:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_404_NOT_FOUND,
 | 
			
		||||
            detail="Model not found",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Check if the user has a custom task model
 | 
			
		||||
    # If the user has a custom task model, use that model
 | 
			
		||||
    task_model_id = get_task_model_id(
 | 
			
		||||
        model_id,
 | 
			
		||||
        request.app.state.config.TASK_MODEL,
 | 
			
		||||
        request.app.state.config.TASK_MODEL_EXTERNAL,
 | 
			
		||||
        models,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    log.debug(
 | 
			
		||||
        f"generating chat title using model {task_model_id} for user {user.email} "
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
 | 
			
		||||
        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    else:
 | 
			
		||||
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
 | 
			
		||||
    content = follow_up_generation_template(
 | 
			
		||||
        template,
 | 
			
		||||
        form_data["messages"],
 | 
			
		||||
        {
 | 
			
		||||
            "name": user.name,
 | 
			
		||||
            "location": user.info.get("location") if user.info else None,
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    payload = {
 | 
			
		||||
        "model": task_model_id,
 | 
			
		||||
        "messages": [{"role": "user", "content": content}],
 | 
			
		||||
        "stream": False,
 | 
			
		||||
        "metadata": {
 | 
			
		||||
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
 | 
			
		||||
            "task": str(TASKS.FOLLOW_UP_GENERATION),
 | 
			
		||||
            "task_body": form_data,
 | 
			
		||||
            "chat_id": form_data.get("chat_id", None),
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Process the payload through the pipeline
 | 
			
		||||
    try:
 | 
			
		||||
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        raise e
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        return await generate_chat_completion(request, form_data=payload, user=user)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.error("Exception occurred", exc_info=True)
 | 
			
		||||
        return JSONResponse(
 | 
			
		||||
            status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
            content={"detail": "An internal error has occurred."},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/tags/completions")
 | 
			
		||||
async def generate_chat_tags(
 | 
			
		||||
    request: Request, form_data: dict, user=Depends(get_verified_user)
 | 
			
		||||
 | 
			
		||||
@ -207,6 +207,24 @@ def title_generation_template(
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def follow_up_generation_template(
 | 
			
		||||
    template: str, messages: list[dict], user: Optional[dict] = None
 | 
			
		||||
) -> str:
 | 
			
		||||
    prompt = get_last_user_message(messages)
 | 
			
		||||
    template = replace_prompt_variable(template, prompt)
 | 
			
		||||
    template = replace_messages_variable(template, messages)
 | 
			
		||||
 | 
			
		||||
    template = prompt_template(
 | 
			
		||||
        template,
 | 
			
		||||
        **(
 | 
			
		||||
            {"user_name": user.get("name"), "user_location": user.get("location")}
 | 
			
		||||
            if user
 | 
			
		||||
            else {}
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tags_generation_template(
 | 
			
		||||
    template: str, messages: list[dict], user: Optional[dict] = None
 | 
			
		||||
) -> str:
 | 
			
		||||
 | 
			
		||||
@ -31,6 +31,8 @@
 | 
			
		||||
		TASK_MODEL_EXTERNAL: '',
 | 
			
		||||
		ENABLE_TITLE_GENERATION: true,
 | 
			
		||||
		TITLE_GENERATION_PROMPT_TEMPLATE: '',
 | 
			
		||||
		ENABLE_FOLLOW_UP_GENERATION: true,
 | 
			
		||||
		FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: '',
 | 
			
		||||
		IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
 | 
			
		||||
		ENABLE_AUTOCOMPLETE_GENERATION: true,
 | 
			
		||||
		AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
 | 
			
		||||
@ -235,6 +237,32 @@
 | 
			
		||||
					</div>
 | 
			
		||||
				{/if}
 | 
			
		||||
 | 
			
		||||
				<div class="mb-2.5 flex w-full items-center justify-between">
 | 
			
		||||
					<div class=" self-center text-xs font-medium">
 | 
			
		||||
						{$i18n.t('Follow Up Generation')}
 | 
			
		||||
					</div>
 | 
			
		||||
 | 
			
		||||
					<Switch bind:state={taskConfig.ENABLE_FOLLOW_UP_GENERATION} />
 | 
			
		||||
				</div>
 | 
			
		||||
 | 
			
		||||
				{#if taskConfig.ENABLE_FOLLOW_UP_GENERATION}
 | 
			
		||||
					<div class="mb-2.5">
 | 
			
		||||
						<div class=" mb-1 text-xs font-medium">{$i18n.t('Follow Up Generation Prompt')}</div>
 | 
			
		||||
 | 
			
		||||
						<Tooltip
 | 
			
		||||
							content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
 | 
			
		||||
							placement="top-start"
 | 
			
		||||
						>
 | 
			
		||||
							<Textarea
 | 
			
		||||
								bind:value={taskConfig.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE}
 | 
			
		||||
								placeholder={$i18n.t(
 | 
			
		||||
									'Leave empty to use the default prompt, or enter a custom prompt'
 | 
			
		||||
								)}
 | 
			
		||||
							/>
 | 
			
		||||
						</Tooltip>
 | 
			
		||||
					</div>
 | 
			
		||||
				{/if}
 | 
			
		||||
 | 
			
		||||
				<div class="mb-2.5 flex w-full items-center justify-between">
 | 
			
		||||
					<div class=" self-center text-xs font-medium">
 | 
			
		||||
						{$i18n.t('Tags Generation')}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user