mirror of https://github.com/open-webui/open-webui
enh: retrieval query generation
This commit is contained in:
parent 09c6e4b92f
commit dbb67a12ca
@@ -177,35 +177,34 @@ def merge_and_sort_query_results(

 def query_collection(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
 ) -> dict:

     results = []
-    query_embedding = embedding_function(query)
-    for collection_name in collection_names:
-        if collection_name:
-            try:
-                result = query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                )
-                if result is not None:
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
-        else:
-            pass
+    for query in queries:
+        query_embedding = embedding_function(query)
+        for collection_name in collection_names:
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass

     return merge_and_sort_query_results(results, k=k)


 def query_collection_with_hybrid_search(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
     reranking_function,
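With this change, query_collection takes the list of generated queries and fans each one out over every collection before merging. A minimal usage sketch follows; the collection names and the embedder are illustrative assumptions, not values from this commit:

queries = ["open webui retrieval query generation", "rag query rewriting"]

merged = query_collection(
    collection_names=["docs-1", "docs-2"],  # hypothetical collection names
    queries=queries,                        # one embedding + lookup per generated query
    embedding_function=lambda q: embedding_model.encode(q),  # assumed embedder
    k=5,
)
# merged is whatever merge_and_sort_query_results(results, k=5) returns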
@@ -215,15 +214,16 @@ def query_collection_with_hybrid_search(
     error = False
     for collection_name in collection_names:
         try:
-            result = query_doc_with_hybrid_search(
-                collection_name=collection_name,
-                query=query,
-                embedding_function=embedding_function,
-                k=k,
-                reranking_function=reranking_function,
-                r=r,
-            )
-            results.append(result)
+            for query in queries:
+                result = query_doc_with_hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    embedding_function=embedding_function,
+                    k=k,
+                    reranking_function=reranking_function,
+                    r=r,
+                )
+                results.append(result)
         except Exception as e:
             log.exception(
                 "Error when querying the collection with " f"hybrid_search: {e}"
@@ -309,15 +309,14 @@ def get_embedding_function(

 def get_rag_context(
     files,
-    messages,
+    queries,
     embedding_function,
     k,
     reranking_function,
     r,
     hybrid_search,
 ):
-    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
-    query = get_last_user_message(messages)
+    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")

     extracted_collections = []
     relevant_contexts = []
@@ -359,7 +358,7 @@ def get_rag_context(
             try:
                 context = query_collection_with_hybrid_search(
                     collection_names=collection_names,
-                    query=query,
+                    queries=queries,
                     embedding_function=embedding_function,
                     k=k,
                     reranking_function=reranking_function,
@@ -374,7 +373,7 @@ def get_rag_context(
             if (not hybrid_search) or (context is None):
                 context = query_collection(
                     collection_names=collection_names,
-                    query=query,
+                    queries=queries,
                     embedding_function=embedding_function,
                     k=k,
                 )
@@ -941,19 +941,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
     os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
 )

-ENABLE_SEARCH_QUERY = PersistentConfig(
-    "ENABLE_SEARCH_QUERY",
-    "task.search.enable",
-    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
+ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_SEARCH_QUERY_GENERATION",
+    "task.query.search.enable",
+    os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
+)
+
+ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_RETRIEVAL_QUERY_GENERATION",
+    "task.query.retrieval.enable",
+    os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
 )


-SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
-    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
-    "task.search.prompt_template",
-    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
+QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.query.prompt_template",
+    os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
 )

+DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
+Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
+
+### Guidelines:
+- Respond exclusively with a JSON object.
+- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
+- If no search query is necessary, output should be: { "queries": [] }
+- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
+- Be concise, focusing strictly on composing search queries with no additional commentary or text.
+- When in doubt, prefer to suggest a search for comprehensiveness.
+- Today's date is: {{CURRENT_DATE}}
+
+### Output:
+JSON format: {
+"queries": ["query1", "query2"]
+}
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>
+"""
+

 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -1127,27 +1157,6 @@ RAG_TEXT_SPLITTER = PersistentConfig(
 )


-ENABLE_RAG_QUERY_GENERATION = PersistentConfig(
-    "ENABLE_RAG_QUERY_GENERATION",
-    "rag.query_generation.enable",
-    os.environ.get("ENABLE_RAG_QUERY_GENERATION", "False").lower() == "true",
-)
-
-DEFAULT_RAG_QUERY_GENERATION_TEMPLATE = """Given the user's message and interaction history, decide if a file search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt.
-User Message:
-{{prompt:end:4000}}
-Interaction History:
-{{MESSAGES:END:6}}
-Search Query:"""
-
-
-RAG_QUERY_GENERATION_TEMPLATE = PersistentConfig(
-    "RAG_QUERY_GENERATION_TEMPLATE",
-    "rag.query_generation.template",
-    os.environ.get("RAG_QUERY_GENERATION_TEMPLATE", ""),
-)
-
-
 TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
 TIKTOKEN_ENCODING_NAME = PersistentConfig(
     "TIKTOKEN_ENCODING_NAME",
@@ -78,11 +78,13 @@ from open_webui.config import (
     ENV,
     FRONTEND_BUILD_DIR,
     OAUTH_PROVIDERS,
-    ENABLE_SEARCH_QUERY,
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     STATIC_DIR,
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
+    ENABLE_SEARCH_QUERY_GENERATION,
+    ENABLE_RETRIEVAL_QUERY_GENERATION,
+    QUERY_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.task import (
     moa_response_generation_template,
     tags_generation_template,
-    search_query_generation_template,
+    query_generation_template,
     emoji_generation_template,
     title_generation_template,
     tools_function_calling_generation_template,
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE


-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
-app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
-)
+app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
+app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
+app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE

 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
     return body, {"contexts": contexts, "citations": citations}


-async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+async def chat_completion_files_handler(
+    body: dict, user: UserModel
+) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []

+    try:
+        queries_response = await generate_queries(
+            {
+                "model": body["model"],
+                "messages": body["messages"],
+                "type": "retrieval",
+            },
+            user,
+        )
+        queries_response = queries_response["choices"][0]["message"]["content"]
+
+        try:
+            queries_response = json.loads(queries_response)
+        except Exception as e:
+            queries_response = {"queries": []}
+
+        queries = queries_response.get("queries", [])
+    except Exception as e:
+        queries = []
+
+    if len(queries) == 0:
+        queries = [get_last_user_message(body["messages"])]
+
+    print(f"{queries=}")
+
     if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
-            messages=body["messages"],
+            queries=queries,
             embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
             k=retrieval_app.state.config.TOP_K,
             reranking_function=retrieval_app.state.sentence_transformer_rf,
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             log.exception(e)

         try:
-            body, flags = await chat_completion_files_handler(body)
+            body, flags = await chat_completion_files_handler(body, user)
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }

@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     TAGS_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_TAGS_GENERATION: bool
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
-    ENABLE_SEARCH_QUERY: bool
+    ENABLE_SEARCH_QUERY_GENERATION: bool
+    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
+    QUERY_GENERATION_PROMPT_TEMPLATE: str
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         form_data.TAGS_GENERATION_PROMPT_TEMPLATE
     )
     app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
-    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
+        form_data.ENABLE_SEARCH_QUERY_GENERATION
+    )
+    app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
+        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
+    )
+
+    app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
     )
-    app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
     app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
         form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     )
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }

@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/query/completions")
-async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
-    print("generate_search_query")
-    if not app.state.config.ENABLE_SEARCH_QUERY:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"Search query generation is disabled",
-        )
+@app.post("/api/task/queries/completions")
+async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_queries")
+    type = form_data.get("type")
+    if type == "web_search":
+        if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Search query generation is disabled",
+            )
+    elif type == "retrieval":
+        if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Query generation is disabled",
+            )

     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
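For reference, a client call to the renamed endpoint might look like the sketch below; the host, token, and model name are placeholder assumptions, not values from this commit:

import requests

resp = requests.post(
    "http://localhost:8080/api/task/queries/completions",  # assumed local deployment
    headers={"Authorization": "Bearer <token>"},
    json={
        "model": "llama3.1",  # any available task model
        "messages": [{"role": "user", "content": "What changed in the latest release?"}],
        "type": "retrieval",  # or "web_search"
    },
)
# The endpoint returns a chat completion; the generated queries are JSON text
# inside choices[0].message.content, as the files handler above expects.
print(resp.json()["choices"][0]["message"]["content"])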
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)

     model = models[task_model_id]

-    if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
-        template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
+        template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
     else:
-        template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
-
-User Message:
-{{prompt:end:4000}}
-
-Interaction History:
-{{MESSAGES:END:6}}
-
-Search Query:"""
-
-    content = search_query_generation_template(
+        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
+
+    content = query_generation_template(
         template, form_data["messages"], {"name": user.name}
     )

@@ -1851,13 +1887,6 @@ Search Query:"""
         "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        **(
-            {"max_tokens": 30}
-            if models[task_model_id]["owned_by"] == "ollama"
-            else {
-                "max_completion_tokens": 30,
-            }
-        ),
         "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
     }
     log.debug(payload)
@@ -163,7 +163,7 @@ def emoji_generation_template(
     return template


-def search_query_generation_template(
+def query_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:
     prompt = get_last_user_message(messages)
@@ -348,15 +348,16 @@ export const generateEmoji = async (
	return null;
 };

-export const generateSearchQuery = async (
+export const generateQueries = async (
	token: string = '',
	model: string,
	messages: object[],
-	prompt: string
+	prompt: string,
+	type?: string = 'web_search'
 ) => {
	let error = null;

-	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
@@ -366,7 +367,8 @@ export const generateSearchQuery = async (
		body: JSON.stringify({
			model: model,
			messages: messages,
-			prompt: prompt
+			prompt: prompt,
+			type: type
		})
	})
		.then(async (res) => {
@@ -385,7 +387,40 @@ export const generateSearchQuery = async (
		throw error;
	}

-	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array
+			if (parsed && parsed.queries) {
+				return Array.isArray(parsed.queries) ? parsed.queries : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
 };

 export const generateMoACompletion = async (
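The step comments above describe a defensive parse: normalize quote characters, slice out the outermost {...} block, then read the "queries" array. The same approach in Python, as a sketch rather than code from this commit:

import json

def extract_queries(content: str) -> list:
    # Normalize quote-like characters so near-JSON output still parses.
    sanitized = (
        content.replace("'", '"').replace("‘", '"').replace("’", '"').replace("`", '"')
    )
    # Slice out the outermost {...} block, if any.
    start, end = sanitized.find("{"), sanitized.rfind("}")
    if start == -1 or end == -1:
        return []
    try:
        parsed = json.loads(sanitized[start : end + 1])
    except Exception:
        return []
    queries = parsed.get("queries", [])
    return queries if isinstance(queries, list) else []

print(extract_queries('Sure! {"queries": ["open webui rag"]}'))  # ['open webui rag']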
@@ -26,8 +26,9 @@
		TITLE_GENERATION_PROMPT_TEMPLATE: '',
		TAGS_GENERATION_PROMPT_TEMPLATE: '',
		ENABLE_TAGS_GENERATION: true,
-		ENABLE_SEARCH_QUERY: true,
-		SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
+		ENABLE_SEARCH_QUERY_GENERATION: true,
+		ENABLE_RETRIEVAL_QUERY_GENERATION: true,
+		QUERY_GENERATION_PROMPT_TEMPLATE: ''
	};

	let promptSuggestions = [];
@@ -164,31 +165,35 @@

			<hr class=" dark:border-gray-850 my-3" />

+			<div class="my-3 flex w-full items-center justify-between">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('Enable Retrieval Query Generation')}
+				</div>
+
+				<Switch bind:state={taskConfig.ENABLE_RETRIEVAL_QUERY_GENERATION} />
+			</div>
+
			<div class="my-3 flex w-full items-center justify-between">
				<div class=" self-center text-xs font-medium">
					{$i18n.t('Enable Web Search Query Generation')}
				</div>

-				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY} />
+				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY_GENERATION} />
			</div>

-			{#if taskConfig.ENABLE_SEARCH_QUERY}
-				<div class="">
-					<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
+			<div class="">
+				<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Query Generation Prompt')}</div>

				<Tooltip
					content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
					placement="top-start"
				>
					<Textarea
-						bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
-						placeholder={$i18n.t(
-							'Leave empty to use the default prompt, or enter a custom prompt'
-						)}
+						bind:value={taskConfig.QUERY_GENERATION_PROMPT_TEMPLATE}
+						placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
					/>
				</Tooltip>
-				</div>
-			{/if}
+			</div>
		</div>

			<hr class=" dark:border-gray-850 my-3" />
@@ -66,7 +66,7 @@
	import {
		chatCompleted,
		generateTitle,
-		generateSearchQuery,
+		generateQueries,
		chatAction,
		generateMoACompletion,
		generateTags
@@ -2047,17 +2047,17 @@
		history.messages[responseMessageId] = responseMessage;

		const prompt = userMessage.content;
-		let searchQuery = await generateSearchQuery(
+		let queries = await generateQueries(
			localStorage.token,
			model,
			messages.filter((message) => message?.content?.trim()),
			prompt
		).catch((error) => {
			console.log(error);
-			return prompt;
+			return [];
		});

-		if (!searchQuery || searchQuery == '') {
+		if (queries.length === 0) {
			responseMessage.statusHistory.push({
				done: true,
				error: true,
@@ -2068,6 +2068,8 @@
			return;
		}

+		const searchQuery = queries[0];
+
		responseMessage.statusHistory.push({
			done: false,
			action: 'web_search',