From dbb67a12cac22e6cb3466132fb42a6a7582e575b Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 19 Nov 2024 02:24:32 -0800 Subject: [PATCH] enh: retrieval query generation --- backend/open_webui/apps/retrieval/utils.py | 65 +++++---- backend/open_webui/config.py | 67 +++++----- backend/open_webui/main.py | 123 +++++++++++------- backend/open_webui/utils/task.py | 2 +- src/lib/apis/index.ts | 45 ++++++- .../admin/Settings/Interface.svelte | 43 +++--- src/lib/components/chat/Chat.svelte | 10 +- 7 files changed, 217 insertions(+), 138 deletions(-) diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 77d97814c..6d87c98e3 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -177,35 +177,34 @@ def merge_and_sort_query_results( def query_collection( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, ) -> dict: - results = [] - query_embedding = embedding_function(query) - - for collection_name in collection_names: - if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + for query in queries: + query_embedding = embedding_function(query) + for collection_name in collection_names: + if collection_name: + try: + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + results.append(result.model_dump()) + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + else: + pass return merge_and_sort_query_results(results, k=k) def query_collection_with_hybrid_search( collection_names: list[str], - query: str, + queries: list[str], embedding_function, k: int, reranking_function, @@ -215,15 +214,16 @@ def query_collection_with_hybrid_search( error = False for collection_name in collection_names: try: - result = query_doc_with_hybrid_search( - collection_name=collection_name, - query=query, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - r=r, - ) - results.append(result) + for query in queries: + result = query_doc_with_hybrid_search( + collection_name=collection_name, + query=query, + embedding_function=embedding_function, + k=k, + reranking_function=reranking_function, + r=r, + ) + results.append(result) except Exception as e: log.exception( "Error when querying the collection with " f"hybrid_search: {e}" @@ -309,15 +309,14 @@ def get_embedding_function( def get_rag_context( files, - messages, + queries, embedding_function, k, reranking_function, r, hybrid_search, ): - log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}") - query = get_last_user_message(messages) + log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}") extracted_collections = [] relevant_contexts = [] @@ -359,7 +358,7 @@ def get_rag_context( try: context = query_collection_with_hybrid_search( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, reranking_function=reranking_function, @@ -374,7 +373,7 @@ def get_rag_context( if (not hybrid_search) or (context is None): context = query_collection( collection_names=collection_names, - query=query, + queries=queries, embedding_function=embedding_function, k=k, ) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 621ebf35b..c33895396 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -941,19 +941,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig( os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true", ) -ENABLE_SEARCH_QUERY = PersistentConfig( - "ENABLE_SEARCH_QUERY", - "task.search.enable", - os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true", + +ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig( + "ENABLE_SEARCH_QUERY_GENERATION", + "task.query.search.enable", + os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true", +) + +ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig( + "ENABLE_RETRIEVAL_QUERY_GENERATION", + "task.query.retrieval.enable", + os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true", ) -SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", - "task.search.prompt_template", - os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""), +QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "QUERY_GENERATION_PROMPT_TEMPLATE", + "task.query.prompt_template", + os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""), ) +DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: +Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list. + +### Guidelines: +- Respond exclusively with a JSON object. +- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise. +- If no search query is necessary, output should be: { "queries": [] } +- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required. +- Be concise, focusing strictly on composing search queries with no additional commentary or text. +- When in doubt, prefer to suggest a search for comprehensiveness. +- Today's date is: {{CURRENT_DATE}} + +### Output: +JSON format: { + "queries": ["query1", "query2"] +} + +### Chat History: + +{{MESSAGES:END:6}} + +""" + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", @@ -1127,27 +1157,6 @@ RAG_TEXT_SPLITTER = PersistentConfig( ) -ENABLE_RAG_QUERY_GENERATION = PersistentConfig( - "ENABLE_RAG_QUERY_GENERATION", - "rag.query_generation.enable", - os.environ.get("ENABLE_RAG_QUERY_GENERATION", "False").lower() == "true", -) - -DEFAULT_RAG_QUERY_GENERATION_TEMPLATE = """Given the user's message and interaction history, decide if a file search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. -User Message: -{{prompt:end:4000}} -Interaction History: -{{MESSAGES:END:6}} -Search Query:""" - - -RAG_QUERY_GENERATION_TEMPLATE = PersistentConfig( - "RAG_QUERY_GENERATION_TEMPLATE", - "rag.query_generation.template", - os.environ.get("RAG_QUERY_GENERATION_TEMPLATE", ""), -) - - TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") TIKTOKEN_ENCODING_NAME = PersistentConfig( "TIKTOKEN_ENCODING_NAME", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6f2c7ac42..04c86395a 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -78,11 +78,13 @@ from open_webui.config import ( ENV, FRONTEND_BUILD_DIR, OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, STATIC_DIR, TASK_MODEL, TASK_MODEL_EXTERNAL, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_RETRIEVAL_QUERY_GENERATION, + QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( moa_response_generation_template, tags_generation_template, - search_query_generation_template, + query_generation_template, emoji_generation_template, title_generation_template, tools_function_calling_generation_template, @@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY -app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE -) +app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION +app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION +app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE @@ -492,14 +493,41 @@ async def chat_completion_tools_handler( return body, {"contexts": contexts, "citations": citations} -async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: +async def chat_completion_files_handler( + body: dict, user: UserModel +) -> tuple[dict, dict[str, list]]: contexts = [] citations = [] + try: + queries_response = await generate_queries( + { + "model": body["model"], + "messages": body["messages"], + "type": "retrieval", + }, + user, + ) + queries_response = queries_response["choices"][0]["message"]["content"] + + try: + queries_response = json.loads(queries_response) + except Exception as e: + queries_response = {"queries": []} + + queries = queries_response.get("queries", []) + except Exception as e: + queries = [] + + if len(queries) == 0: + queries = [get_last_user_message(body["messages"])] + + print(f"{queries=}") + if files := body.get("metadata", {}).get("files", None): contexts, citations = get_rag_context( files=files, - messages=body["messages"], + queries=queries, embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, k=retrieval_app.state.config.TOP_K, reranking_function=retrieval_app.state.sentence_transformer_rf, @@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): log.exception(e) try: - body, flags = await chat_completion_files_handler(body) + body, flags = await chat_completion_files_handler(body, user) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)): "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel): TITLE_GENERATION_PROMPT_TEMPLATE: str TAGS_GENERATION_PROMPT_TEMPLATE: str ENABLE_TAGS_GENERATION: bool - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str - ENABLE_SEARCH_QUERY: bool + ENABLE_SEARCH_QUERY_GENERATION: bool + ENABLE_RETRIEVAL_QUERY_GENERATION: bool + QUERY_GENERATION_PROMPT_TEMPLATE: str TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u form_data.TAGS_GENERATION_PROMPT_TEMPLATE ) app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION - - app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( - form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( + form_data.ENABLE_SEARCH_QUERY_GENERATION + ) + app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( + form_data.ENABLE_RETRIEVAL_QUERY_GENERATION + ) + + app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.QUERY_GENERATION_PROMPT_TEMPLATE ) - app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) @@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, - "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY, + "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, + "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, + "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } return await generate_chat_completions(form_data=payload, user=user) -@app.post("/api/task/query/completions") -async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): - print("generate_search_query") - if not app.state.config.ENABLE_SEARCH_QUERY: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Search query generation is disabled", - ) +@app.post("/api/task/queries/completions") +async def generate_queries(form_data: dict, user=Depends(get_verified_user)): + print("generate_queries") + type = form_data.get("type") + if type == "web_search": + if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search query generation is disabled", + ) + elif type == "retrieval": + if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Query generation is disabled", + ) model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) model = models[task_model_id] - if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": - template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": + template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE else: - template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}. + template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE -User Message: -{{prompt:end:4000}} - -Interaction History: -{{MESSAGES:END:6}} - -Search Query:""" - - content = search_query_generation_template( + content = query_generation_template( template, form_data["messages"], {"name": user.name} ) @@ -1851,13 +1887,6 @@ Search Query:""" "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - **( - {"max_tokens": 30} - if models[task_model_id]["owned_by"] == "ollama" - else { - "max_completion_tokens": 30, - } - ), "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, } log.debug(payload) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 799cca11a..28b07da37 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -163,7 +163,7 @@ def emoji_generation_template( return template -def search_query_generation_template( +def query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: prompt = get_last_user_message(messages) diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index d22923670..56d4d9bc6 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -348,15 +348,16 @@ export const generateEmoji = async ( return null; }; -export const generateSearchQuery = async ( +export const generateQueries = async ( token: string = '', model: string, messages: object[], - prompt: string + prompt: string, + type?: string = 'web_search' ) => { let error = null; - const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { method: 'POST', headers: { Accept: 'application/json', @@ -366,7 +367,8 @@ export const generateSearchQuery = async ( body: JSON.stringify({ model: model, messages: messages, - prompt: prompt + prompt: prompt, + type: type }) }) .then(async (res) => { @@ -385,7 +387,40 @@ export const generateSearchQuery = async ( throw error; } - return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; + + try { + // Step 1: Safely extract the response string + const response = res?.choices[0]?.message?.content ?? ''; + + // Step 2: Attempt to fix common JSON format issues like single quotes + const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON + + // Step 3: Find the relevant JSON block within the response + const jsonStartIndex = sanitizedResponse.indexOf('{'); + const jsonEndIndex = sanitizedResponse.lastIndexOf('}'); + + // Step 4: Check if we found a valid JSON block (with both `{` and `}`) + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.queries) { + return Array.isArray(parsed.queries) ? parsed.queries : []; + } else { + return []; + } + } + + // If no valid JSON block found, return an empty array + return []; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return []; + } }; export const generateMoACompletion = async ( diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index b79ce0747..3043ffae4 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -26,8 +26,9 @@ TITLE_GENERATION_PROMPT_TEMPLATE: '', TAGS_GENERATION_PROMPT_TEMPLATE: '', ENABLE_TAGS_GENERATION: true, - ENABLE_SEARCH_QUERY: true, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: '' + ENABLE_SEARCH_QUERY_GENERATION: true, + ENABLE_RETRIEVAL_QUERY_GENERATION: true, + QUERY_GENERATION_PROMPT_TEMPLATE: '' }; let promptSuggestions = []; @@ -164,31 +165,35 @@
+
+
+ {$i18n.t('Enable Retrieval Query Generation')} +
+ + +
+
{$i18n.t('Enable Web Search Query Generation')}
- +
- {#if taskConfig.ENABLE_SEARCH_QUERY} -
-
{$i18n.t('Search Query Generation Prompt')}
+
+
{$i18n.t('Query Generation Prompt')}
- -