refac: web search

This commit is contained in:
Timothy J. Baek 2024-09-07 04:50:29 +01:00
parent ff46fe2b4a
commit 5c8fb4b3d5
7 changed files with 142 additions and 103 deletions

View File

@ -59,6 +59,7 @@ async def download_chat_as_pdf(
form_data: ChatForm,
):
global FONTS_DIR
pdf = FPDF()
pdf.add_page()

View File

@ -898,53 +898,27 @@ TASK_MODEL_EXTERNAL = PersistentConfig(
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
Prompt: {{prompt:middletruncate:8000}}""",
),
# Toggle for search query generation, registered under the config path
# "task.search.enable". The initial value is read from the
# ENABLE_SEARCH_QUERY environment variable (default "True") and compared
# case-insensitively against "true", so any other value evaluates to False.
ENABLE_SEARCH_QUERY = PersistentConfig(
"ENABLE_SEARCH_QUERY",
"task.search.enable",
os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
),
os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text.""",
),
os.environ.get("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""),
)

View File

@ -63,8 +63,8 @@ from open_webui.config import (
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
@ -199,9 +199,7 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@ -397,8 +395,13 @@ async def chat_completion_tools_handler(
specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs)
if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
else:
template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
tools_function_calling_prompt = tools_function_calling_generation_template(
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs
template, tools_specs
)
log.info(f"{tools_function_calling_prompt=}")
payload = get_tools_function_calling_payload(
@ -1312,8 +1315,8 @@ async def get_task_config(user=Depends(get_verified_user)):
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@ -1323,7 +1326,7 @@ class TaskConfigForm(BaseModel):
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
ENABLE_SEARCH_QUERY: bool
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@ -1337,9 +1340,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@ -1349,7 +1350,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@ -1371,7 +1372,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print(model_id)
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
else:
template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
Prompt: {{prompt:middletruncate:8000}}"""
content = title_generation_template(
template,
@ -1416,11 +1430,10 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
if not app.state.config.ENABLE_SEARCH_QUERY:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
detail=f"Search query generation is disabled",
)
model_id = form_data["model"]
@ -1436,12 +1449,22 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
print(model_id)
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
else:
template = """You are tasked with assessing the need for a web search based on the current question and the context provided by the previous interactions. If the question requires a web search, generate an appropriate query for a Google search and respond with only the query. If the question can be answered without a web search or does not require further information, return an empty string. Today's date is {{CURRENT_DATE}}.
Interaction History:
{{MESSAGES:END:6}}
Current Question:
{{prompt:end:4000}}"""
content = search_query_generation_template(
template, form_data["prompt"], {"name": user.name}
template, form_data["messages"], {"name": user.name}
)
print("content", content)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],

View File

@ -4,6 +4,9 @@ from datetime import datetime
from typing import Optional
from open_webui.utils.misc import get_last_user_message, get_messages_content
def prompt_template(
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
@ -37,9 +40,7 @@ def prompt_template(
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replace_prompt_variable(template: str, prompt: str) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
@ -66,7 +67,13 @@ def title_generation_template(
replacement_function,
template,
)
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
template = replace_prompt_variable(template, prompt)
template = prompt_template(
template,
**(
@ -79,36 +86,50 @@ def title_generation_template(
return template
def search_query_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replace_messages_variable(template: str, messages: list[dict]) -> str:
    """Expand ``{{MESSAGES...}}`` placeholders in *template*.

    Supported placeholders (``N`` is a non-negative integer counting
    messages, not characters):

    * ``{{MESSAGES}}`` - the whole message list.
    * ``{{MESSAGES:START:N}}`` - the first ``N`` messages.
    * ``{{MESSAGES:END:N}}`` - the last ``N`` messages.
    * ``{{MESSAGES:MIDDLETRUNCATE:N}}`` - all messages when there are at most
      ``N``; otherwise the first ``N // 2`` and the last ``N - N // 2``
      messages, rendered separately and joined by a newline.

    Each selected slice is rendered with ``get_messages_content``.

    Args:
        template: Prompt template that may contain the placeholders above.
        messages: Conversation messages, in order.

    Returns:
        The template with every recognised placeholder replaced.
    """

    def last_messages(count: int) -> list[dict]:
        # ``messages[-0:]`` is the *entire* list, so a count of zero has to be
        # handled explicitly to yield no messages (mirroring ``messages[:0]``).
        return messages[-count:] if count > 0 else []

    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        # Select messages based on the number of messages requested.
        if full_match == "{{MESSAGES}}":
            return get_messages_content(messages)
        elif start_length is not None:
            return get_messages_content(messages[: int(start_length)])
        elif end_length is not None:
            return get_messages_content(last_messages(int(end_length)))
        elif middle_length is not None:
            mid = int(middle_length)
            if len(messages) <= mid:
                return get_messages_content(messages)

            # Middle truncation: keep the head and the tail of the history.
            # For odd counts the extra message goes to the tail.
            half = mid // 2
            start_msgs = messages[:half]
            end_msgs = last_messages(mid - half)
            formatted_start = get_messages_content(start_msgs)
            formatted_end = get_messages_content(end_msgs)
            return f"{formatted_start}\n{formatted_end}"
        return ""

    template = re.sub(
        r"{{MESSAGES}}|{{MESSAGES:START:(\d+)}}|{{MESSAGES:END:(\d+)}}|{{MESSAGES:MIDDLETRUNCATE:(\d+)}}",
        replacement_function,
        template,
    )
    return template
def search_query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(

View File

@ -23,8 +23,8 @@
TASK_MODEL: '',
TASK_MODEL_EXTERNAL: '',
TITLE_GENERATION_PROMPT_TEMPLATE: '',
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: '',
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: 0
ENABLE_SEARCH_QUERY: true,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
};
let promptSuggestions = [];
@ -43,7 +43,6 @@
taskConfig = await getTaskConfig(localStorage.token);
promptSuggestions = $config?.default_prompt_suggestions;
banners = await getBanners(localStorage.token);
});
@ -119,33 +118,50 @@
</div>
<div class="mt-3">
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Title Generation Prompt')}</div>
<textarea
bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="6"
/>
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Title Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
bind:value={taskConfig.TITLE_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="3"
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
</div>
<div class="mt-3">
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
<textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="6"
/>
</div>
<hr class=" dark:border-gray-850 my-3" />
<div class="mt-3">
<div class=" mb-2.5 text-sm font-medium">
{$i18n.t('Search Query Generation Prompt Length Threshold')}
<div class="my-3 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Enable Web Search Query Generation')}
</div>
<input
bind:value={taskConfig.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD}
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
type="number"
/>
<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY} />
</div>
{#if taskConfig.ENABLE_SEARCH_QUERY}
<div class="">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
class="w-full rounded-lg py-3 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none resize-none"
rows="3"
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
/>
</Tooltip>
</div>
{/if}
</div>
<hr class=" dark:border-gray-850 my-3" />

View File

@ -732,6 +732,8 @@
responseMessage.userContext = userContext;
const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
scrollToBottom();
if (webSearchEnabled) {
await getWebSearchResults(model.id, parentId, responseMessageId);
}
@ -1512,23 +1514,25 @@
messages = messages;
const prompt = userMessage.content;
let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch(
(error) => {
console.log(error);
return prompt;
}
);
let searchQuery = await generateSearchQuery(
localStorage.token,
model,
messages.filter((message) => message?.content?.trim()),
prompt
).catch((error) => {
console.log(error);
return prompt;
});
if (!searchQuery) {
toast.warning($i18n.t('No search query generated'));
if (!searchQuery || searchQuery == '') {
responseMessage.statusHistory.push({
done: true,
error: true,
action: 'web_search',
description: 'No search query generated'
});
messages = messages;
return;
}
responseMessage.statusHistory.push({

View File

@ -309,7 +309,7 @@
{:else}
<div class="w-full pt-2">
{#key chatId}
{#each messages as message, messageIdx}
{#each messages as message, messageIdx (message.id)}
<div class=" w-full {messageIdx === messages.length - 1 ? ' pb-12' : ''}">
<div
class="flex flex-col justify-between px-5 mb-3 {($settings?.widescreenMode ?? null)