refac
This commit is contained in:
@@ -74,6 +74,7 @@ from open_webui.utils.misc import (
|
||||
add_or_update_user_message,
|
||||
get_last_user_message,
|
||||
get_last_assistant_message,
|
||||
get_system_message,
|
||||
prepend_to_first_user_message_content,
|
||||
convert_logit_bias_input_to_json,
|
||||
)
|
||||
@@ -84,7 +85,7 @@ from open_webui.utils.filter import (
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.utils.payload import apply_model_system_prompt_to_body
|
||||
from open_webui.utils.payload import apply_system_prompt_to_body
|
||||
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -737,6 +738,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
system_message = get_system_message(form_data.get("messages", []))
|
||||
if system_message:
|
||||
form_data = apply_system_prompt_to_body(
|
||||
system_message.get("content"), form_data, metadata, user
|
||||
)
|
||||
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
@@ -778,7 +785,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
if folder and folder.data:
|
||||
if "system_prompt" in folder.data:
|
||||
form_data = apply_model_system_prompt_to_body(
|
||||
form_data = apply_system_prompt_to_body(
|
||||
folder.data["system_prompt"], form_data, metadata, user
|
||||
)
|
||||
if "files" in folder.data:
|
||||
|
||||
@@ -9,7 +9,7 @@ import json
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(
|
||||
def apply_system_prompt_to_body(
|
||||
system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
|
||||
) -> dict:
|
||||
if not system:
|
||||
@@ -22,15 +22,7 @@ def apply_model_system_prompt_to_body(
|
||||
system = prompt_variables_template(system, variables)
|
||||
|
||||
# Legacy (API Usage)
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
|
||||
system = prompt_template(system, **template_params)
|
||||
system = prompt_template(system, user)
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -38,9 +38,42 @@ def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
|
||||
return template
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||
) -> str:
|
||||
def prompt_template(template: str, user: Optional[Any] = None) -> str:
|
||||
if hasattr(user, "model_dump"):
|
||||
user = user.model_dump()
|
||||
|
||||
USER_VARIABLES = {}
|
||||
|
||||
if isinstance(user, dict):
|
||||
birth_date = user.get("date_of_birth")
|
||||
age = None
|
||||
|
||||
if birth_date:
|
||||
try:
|
||||
# If birth_date is str, convert to datetime
|
||||
if isinstance(birth_date, str):
|
||||
birth_date = datetime.strptime(birth_date, "%Y-%m-%d")
|
||||
|
||||
today = datetime.now()
|
||||
age = (
|
||||
today.year
|
||||
- birth_date.year
|
||||
- ((today.month, today.day) < (birth_date.month, birth_date.day))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
USER_VARIABLES = {
|
||||
"name": str(user.get("name")),
|
||||
"location": str(user.get("info", {}).get("location")),
|
||||
"bio": str(user.get("bio")),
|
||||
"gender": str(user.get("gender")),
|
||||
"birth_date": str(birth_date),
|
||||
"age": str(age),
|
||||
}
|
||||
|
||||
print(USER_VARIABLES)
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
|
||||
@@ -56,19 +89,20 @@ def prompt_template(
|
||||
)
|
||||
template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday)
|
||||
|
||||
if user_name:
|
||||
# Replace {{USER_NAME}} in the template with the user's name
|
||||
template = template.replace("{{USER_NAME}}", user_name)
|
||||
else:
|
||||
# Replace {{USER_NAME}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_NAME}}", "Unknown")
|
||||
|
||||
if user_location:
|
||||
# Replace {{USER_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{USER_LOCATION}}", user_location)
|
||||
else:
|
||||
# Replace {{USER_LOCATION}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_LOCATION}}", "Unknown")
|
||||
template = template.replace("{{USER_NAME}}", USER_VARIABLES.get("name", "Unknown"))
|
||||
template = template.replace("{{USER_BIO}}", USER_VARIABLES.get("bio", "Unknown"))
|
||||
template = template.replace(
|
||||
"{{USER_GENDER}}", USER_VARIABLES.get("gender", "Unknown")
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_BIRTH_DATE}}", USER_VARIABLES.get("birth_date", "Unknown")
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_AGE}}", str(USER_VARIABLES.get("age", "Unknown"))
|
||||
)
|
||||
template = template.replace(
|
||||
"{{USER_LOCATION}}", USER_VARIABLES.get("location", "Unknown")
|
||||
)
|
||||
|
||||
return template
|
||||
|
||||
@@ -189,90 +223,56 @@ def rag_template(template: str, context: str, query: str):
|
||||
|
||||
|
||||
def title_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
def follow_up_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def image_prompt_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def emoji_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
template: str, prompt: str, user: Optional[Any] = None
|
||||
) -> str:
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
|
||||
return template
|
||||
|
||||
@@ -282,38 +282,24 @@ def autocomplete_generation_template(
|
||||
prompt: str,
|
||||
messages: Optional[list[dict]] = None,
|
||||
type: Optional[str] = None,
|
||||
user: Optional[dict] = None,
|
||||
user: Optional[Any] = None,
|
||||
) -> str:
|
||||
template = template.replace("{{TYPE}}", type if type else "")
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
def query_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
template: str, messages: list[dict], user: Optional[Any] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
template = prompt_template(template, user)
|
||||
return template
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user