From 6b25139d4fc3f8fa271b986c2093d837b1838ad1 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 24 Dec 2024 17:52:57 -0700 Subject: [PATCH] refac: web search --- backend/open_webui/main.py | 1 + backend/open_webui/routers/retrieval.py | 9 +- backend/open_webui/utils/middleware.py | 152 +++++++++++++++++- src/lib/components/chat/Chat.svelte | 99 +----------- .../chat/Messages/ResponseMessage.svelte | 18 ++- 5 files changed, 179 insertions(+), 100 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d1cca2ac4..fdd934cc7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -856,6 +856,7 @@ async def chat_completion( "session_id": form_data.pop("session_id", None), "tool_ids": form_data.get("tool_ids", None), "files": form_data.get("files", None), + "features": form_data.get("features", None), } form_data["metadata"] = metadata diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 6d9cd0758..e31d5c01d 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: @router.post("/process/web/search") -def process_web_search( +async def process_web_search( request: Request, form_data: SearchForm, user=Depends(get_verified_user) ): try: @@ -1256,9 +1256,11 @@ def process_web_search( detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), ) + log.debug(f"web_results: {web_results}") + try: collection_name = form_data.collection_name - if collection_name == "": + if collection_name == "" or collection_name is None: collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ :63 ] @@ -1269,8 +1271,7 @@ def process_web_search( verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) - docs = loader.aload() - + docs = loader.load() save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 475bb539f..078c1e07e 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -29,6 +29,7 @@ from open_webui.routers.tasks import ( generate_title, generate_chat_tags, ) +from open_webui.routers.retrieval import process_web_search, SearchForm from open_webui.utils.webhook import post_webhook @@ -333,6 +334,149 @@ async def chat_completion_tools_handler( return body, {"sources": sources} +async def chat_web_search_handler( + request: Request, form_data: dict, extra_params: dict, user +): + event_emitter = extra_params["__event_emitter__"] + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "Generating search query", + "done": False, + }, + } + ) + + messages = form_data["messages"] + user_message = get_last_user_message(messages) + + queries = [] + try: + res = await generate_queries( + request, + { + "model": form_data["model"], + "messages": messages, + "prompt": user_message, + "type": "web_search", + }, + user, + ) + + response = res["choices"][0]["message"]["content"] + + try: + bracket_start = response.find("{") + bracket_end = response.rfind("}") + 1 + + if bracket_start == -1 or bracket_end == -1: + raise Exception("No JSON object found in the response") + + response = response[bracket_start:bracket_end] + queries = json.loads(response) + queries = queries.get("queries", []) + except Exception as e: + queries = [response] + + except Exception as e: + log.exception(e) + queries = [user_message] + + if len(queries) == 0: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "No search query generated", + "done": True, + }, + } + ) + return + + searchQuery = queries[0] + + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": 'Searching "{{searchQuery}}"', + "query": searchQuery, + "done": False, + }, + } + ) + + try: + results = await process_web_search( + request, + SearchForm( + **{ + "query": searchQuery, + } + ), + user, + ) + + if results: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "Searched {{count}} sites", + "query": searchQuery, + "urls": results["filenames"], + "done": True, + }, + } + ) + + files = form_data.get("files", []) + files.append( + { + "collection_name": results["collection_name"], + "name": searchQuery, + "type": "web_search_results", + "urls": results["filenames"], + } + ) + form_data["files"] = files + else: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "No search results found", + "query": searchQuery, + "done": True, + "error": True, + }, + } + ) + except Exception as e: + log.exception(e) + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": 'Error searching "{{searchQuery}}"', + "query": searchQuery, + "done": True, + "error": True, + }, + } + ) + + return form_data + + async def chat_completion_files_handler( request: Request, body: dict, user: UserModel ) -> tuple[dict, dict[str, list]]: @@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model): knowledge_files = [] for item in model_knowledge: - print(item) if item.get("collection_name"): knowledge_files.append( { @@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model): files.extend(knowledge_files) form_data["files"] = files + features = form_data.pop("features", None) + if features: + if "web_search" in features and features["web_search"]: + form_data = await chat_web_search_handler( + request, form_data, extra_params, user + ) + try: form_data, flags = await chat_completion_filter_functions_handler( request, form_data, model, extra_params diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 57e154b4d..59afbb793 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1419,11 +1419,8 @@ const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); scrollToBottom(); - if (webSearchEnabled) { - await getWebSearchResults(model.id, parentId, responseMessageId); - } - await sendPromptSocket(model, responseMessageId, _chatId); + if (chatEventEmitter) clearInterval(chatEventEmitter); } else { toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); @@ -1533,8 +1530,12 @@ : undefined }, - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + features: { + web_search: webSearchEnabled + }, + session_id: $socket?.id, chat_id: $chatId, id: responseMessageId, @@ -1751,94 +1752,6 @@ } }; - const getWebSearchResults = async ( - model: string, - parentId: string, - responseMessageId: string - ) => { - // TODO: move this to the backend - const responseMessage = history.messages[responseMessageId]; - const userMessage = history.messages[parentId]; - const messages = createMessagesList(history.currentId); - - responseMessage.statusHistory = [ - { - done: false, - action: 'web_search', - description: $i18n.t('Generating search query') - } - ]; - history.messages[responseMessageId] = responseMessage; - - const prompt = userMessage.content; - let queries = await generateQueries( - localStorage.token, - model, - messages.filter((message) => message?.content?.trim()), - prompt - ).catch((error) => { - console.log(error); - return [prompt]; - }); - - if (queries.length === 0) { - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: $i18n.t('No search query generated') - }); - history.messages[responseMessageId] = responseMessage; - return; - } - - const searchQuery = queries[0]; - - responseMessage.statusHistory.push({ - done: false, - action: 'web_search', - description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) - }); - history.messages[responseMessageId] = responseMessage; - - const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => { - console.log(error); - toast.error(error); - - return null; - }); - - if (results) { - responseMessage.statusHistory.push({ - done: true, - action: 'web_search', - description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), - query: searchQuery, - urls: results.filenames - }); - - if (responseMessage?.files ?? undefined === undefined) { - responseMessage.files = []; - } - - responseMessage.files.push({ - collection_name: results.collection_name, - name: searchQuery, - type: 'web_search_results', - urls: results.filenames - }); - history.messages[responseMessageId] = responseMessage; - } else { - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search results found' - }); - history.messages[responseMessageId] = responseMessage; - } - }; - const initChatHandler = async () => { if (!$temporaryChatEnabled) { chat = await createNewChat(localStorage.token, { diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index b8ca4fe79..cd241488e 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -535,7 +535,14 @@ ? 'shimmer' : ''} text-base line-clamp-1 text-wrap" > - {status?.description} + + {#if status?.description.includes('{{count}}')} + {$i18n.t(status?.description, { + count: status?.urls.length + })} + {:else} + {$i18n.t(status?.description)} + {/if} @@ -558,7 +565,14 @@ ? 'shimmer' : ''} text-gray-500 dark:text-gray-500 text-base line-clamp-1 text-wrap" > - {status?.description} + + {#if status?.description.includes('{{searchQuery}}')} + {$i18n.t(status?.description, { + searchQuery: status?.query + })} + {:else} + {$i18n.t(status?.description)} + {/if} {/if}