diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0e0c08f2b..47c3620ec 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1411,6 +1411,34 @@ Strictly return in JSON format: {{MESSAGES:END:6}} """ + +FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", + "task.follow_up.prompt_template", + os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task: +Suggest 3-5 relevant follow-up questions or discussion prompts based on the chat history to help continue or deepen the conversation. +### Guidelines: +- Make questions concise, clear, and directly related to the discussed topic(s). +- Only generate follow-ups that make sense given the chat content and do not repeat what was already covered. +- If the conversation is very short or not specific, suggest more general follow-ups. +- Use the chat's primary language; default to English if multilingual. +- Response must be a JSON array of strings, no extra text or formatting. +### Output: +JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] } +### Chat History: + +{{MESSAGES:END:6}} +""" + +ENABLE_FOLLOW_UP_GENERATION = PersistentConfig( + "ENABLE_FOLLOW_UP_GENERATION", + "task.follow_up.enable", + os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true", +) + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 95c54a0d2..59ee6aaac 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -111,6 +111,7 @@ class TASKS(str, Enum): DEFAULT = lambda task="": f"{task if task else 'generation'}" TITLE_GENERATION = "title_generation" + FOLLOW_UP_GENERATION = "follow_up_generation" TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6bdcf4957..a75aebb32 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -359,10 +359,12 @@ from open_webui.config import ( TASK_MODEL_EXTERNAL, ENABLE_TAGS_GENERATION, ENABLE_TITLE_GENERATION, + ENABLE_FOLLOW_UP_GENERATION, ENABLE_SEARCH_QUERY_GENERATION, ENABLE_RETRIEVAL_QUERY_GENERATION, ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -959,6 +961,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION +app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -966,6 +969,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) +app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE +) app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index f94346099..3832c0306 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -9,6 +9,7 @@ import re from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, + follow_up_generation_template, query_generation_template, image_prompt_generation_template, autocomplete_generation_template, @@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id from open_webui.config import ( DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel): ENABLE_AUTOCOMPLETE_GENERATION: bool AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str + ENABLE_FOLLOW_UP_GENERATION: bool ENABLE_TAGS_GENERATION: bool ENABLE_SEARCH_QUERY_GENERATION: bool ENABLE_RETRIEVAL_QUERY_GENERATION: bool @@ -94,6 +100,13 @@ async def update_task_config( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = ( + form_data.ENABLE_FOLLOW_UP_GENERATION + ) + request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) @@ -133,6 +146,8 @@ async def update_task_config( "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, @@ -231,6 +246,86 @@ async def generate_title( ) +@router.post("/follow_up/completions") +async def generate_follow_ups( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Follow-up generation is disabled"}, + ) + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + + content = follow_up_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), + "task": str(TASKS.FOLLOW_UP_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) + + @router.post("/tags/completions") async def generate_chat_tags( request: Request, form_data: dict, user=Depends(get_verified_user) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 95018eef1..42b44d516 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -207,6 +207,24 @@ def title_generation_template( return template +def follow_up_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def tags_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 548db5a98..2dcab04b3 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -31,6 +31,8 @@ TASK_MODEL_EXTERNAL: '', ENABLE_TITLE_GENERATION: true, TITLE_GENERATION_PROMPT_TEMPLATE: '', + ENABLE_FOLLOW_UP_GENERATION: true, + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: '', IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '', ENABLE_AUTOCOMPLETE_GENERATION: true, AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1, @@ -235,6 +237,32 @@ {/if} +
+
+ {$i18n.t('Follow Up Generation')} +
+ + +
+ + {#if taskConfig.ENABLE_FOLLOW_UP_GENERATION} +
+
{$i18n.t('Follow Up Generation Prompt')}
+ + +