refac: rag prompt template

This commit is contained in:
Timothy Jaeryang Baek 2024-11-24 18:49:56 -08:00
parent 50c3be2136
commit bd28e1ed7d
4 changed files with 63 additions and 53 deletions

View File

@ -15,8 +15,6 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -238,44 +236,6 @@ def query_collection_with_hybrid_search(
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
"WARNING: Potential prompt injection attack: the RAG "
"context contains '<context>' and '</context>'. This might be "
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
return template
def get_embedding_function(
embedding_engine,
embedding_model,

View File

@ -969,16 +969,16 @@ QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
)
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
Analyze the chat history to determine the necessity of generating search queries. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list.
### Guidelines:
- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is prohibited.
- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
- If no search query is necessary, output should be: { "queries": [] }
- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
- Be concise, focusing strictly on composing search queries with no additional commentary or text.
- When in doubt, prefer to suggest a search for comprehensiveness.
- Today's date is: {{CURRENT_DATE}}
- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited.
- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic.
- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }.
- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information.
- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions.
- Assume today's date is: {{CURRENT_DATE}}.
- Always prioritize providing actionable and broad queries that maximize informational coverage.
### Output:
Strictly return in JSON format:

View File

@ -49,7 +49,9 @@ from open_webui.apps.openai.main import (
get_all_models_responses as get_openai_models_responses,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template
from open_webui.apps.retrieval.utils import get_sources_from_files
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@ -122,11 +124,12 @@ from open_webui.utils.response import (
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
query_generation_template,
emoji_generation_template,
rag_template,
title_generation_template,
query_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
tools_function_calling_generation_template,
)
from open_webui.utils.tools import get_tools

View File

@ -1,11 +1,20 @@
import logging
import math
import re
from datetime import datetime
from typing import Optional
import uuid
from open_webui.utils.misc import get_last_user_message, get_messages_content
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def prompt_template(
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
@ -110,6 +119,44 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
# {{prompt:middletruncate:8000}}
def rag_template(template: str, context: str, query: str):
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
"WARNING: Potential prompt injection attack: the RAG "
"context contains '<context>' and '</context>'. This might be "
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
return template
def title_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str: