mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 06:42:47 +00:00
refac: web search
This commit is contained in:
parent
a074991d3a
commit
6b25139d4f
@ -856,6 +856,7 @@ async def chat_completion(
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
|
@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
|
||||
|
||||
@router.post("/process/web/search")
|
||||
def process_web_search(
|
||||
async def process_web_search(
|
||||
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
@ -1256,9 +1256,11 @@ def process_web_search(
|
||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||||
)
|
||||
|
||||
log.debug(f"web_results: {web_results}")
|
||||
|
||||
try:
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == "":
|
||||
if collection_name == "" or collection_name is None:
|
||||
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
|
||||
:63
|
||||
]
|
||||
@ -1269,8 +1271,7 @@ def process_web_search(
|
||||
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
docs = loader.aload()
|
||||
|
||||
docs = loader.load()
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
|
||||
return {
|
||||
|
@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
|
||||
generate_title,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
|
||||
@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
event_emitter = extra_params["__event_emitter__"]
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Generating search query",
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
|
||||
queries = []
|
||||
try:
|
||||
res = await generate_queries(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
"prompt": user_message,
|
||||
"type": "web_search",
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
response = res["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
bracket_start = response.find("{")
|
||||
bracket_end = response.rfind("}") + 1
|
||||
|
||||
if bracket_start == -1 or bracket_end == -1:
|
||||
raise Exception("No JSON object found in the response")
|
||||
|
||||
response = response[bracket_start:bracket_end]
|
||||
queries = json.loads(response)
|
||||
queries = queries.get("queries", [])
|
||||
except Exception as e:
|
||||
queries = [response]
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
queries = [user_message]
|
||||
|
||||
if len(queries) == 0:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search query generated",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
searchQuery = queries[0]
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
results = await process_web_search(
|
||||
request,
|
||||
SearchForm(
|
||||
**{
|
||||
"query": searchQuery,
|
||||
}
|
||||
),
|
||||
user,
|
||||
)
|
||||
|
||||
if results:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "Searched {{count}} sites",
|
||||
"query": searchQuery,
|
||||
"urls": results["filenames"],
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
files = form_data.get("files", [])
|
||||
files.append(
|
||||
{
|
||||
"collection_name": results["collection_name"],
|
||||
"name": searchQuery,
|
||||
"type": "web_search_results",
|
||||
"urls": results["filenames"],
|
||||
}
|
||||
)
|
||||
form_data["files"] = files
|
||||
else:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": "No search results found",
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"action": "web_search",
|
||||
"description": 'Error searching "{{searchQuery}}"',
|
||||
"query": searchQuery,
|
||||
"done": True,
|
||||
"error": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_completion_files_handler(
|
||||
request: Request, body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
knowledge_files = []
|
||||
for item in model_knowledge:
|
||||
print(item)
|
||||
if item.get("collection_name"):
|
||||
knowledge_files.append(
|
||||
{
|
||||
@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
files.extend(knowledge_files)
|
||||
form_data["files"] = files
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
|
@ -1419,11 +1419,8 @@
|
||||
const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
|
||||
|
||||
scrollToBottom();
|
||||
if (webSearchEnabled) {
|
||||
await getWebSearchResults(model.id, parentId, responseMessageId);
|
||||
}
|
||||
|
||||
await sendPromptSocket(model, responseMessageId, _chatId);
|
||||
|
||||
if (chatEventEmitter) clearInterval(chatEventEmitter);
|
||||
} else {
|
||||
toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
|
||||
@ -1533,8 +1530,12 @@
|
||||
: undefined
|
||||
},
|
||||
|
||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||
files: files.length > 0 ? files : undefined,
|
||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||
features: {
|
||||
web_search: webSearchEnabled
|
||||
},
|
||||
|
||||
session_id: $socket?.id,
|
||||
chat_id: $chatId,
|
||||
id: responseMessageId,
|
||||
@ -1751,94 +1752,6 @@
|
||||
}
|
||||
};
|
||||
|
||||
const getWebSearchResults = async (
|
||||
model: string,
|
||||
parentId: string,
|
||||
responseMessageId: string
|
||||
) => {
|
||||
// TODO: move this to the backend
|
||||
const responseMessage = history.messages[responseMessageId];
|
||||
const userMessage = history.messages[parentId];
|
||||
const messages = createMessagesList(history.currentId);
|
||||
|
||||
responseMessage.statusHistory = [
|
||||
{
|
||||
done: false,
|
||||
action: 'web_search',
|
||||
description: $i18n.t('Generating search query')
|
||||
}
|
||||
];
|
||||
history.messages[responseMessageId] = responseMessage;
|
||||
|
||||
const prompt = userMessage.content;
|
||||
let queries = await generateQueries(
|
||||
localStorage.token,
|
||||
model,
|
||||
messages.filter((message) => message?.content?.trim()),
|
||||
prompt
|
||||
).catch((error) => {
|
||||
console.log(error);
|
||||
return [prompt];
|
||||
});
|
||||
|
||||
if (queries.length === 0) {
|
||||
responseMessage.statusHistory.push({
|
||||
done: true,
|
||||
error: true,
|
||||
action: 'web_search',
|
||||
description: $i18n.t('No search query generated')
|
||||
});
|
||||
history.messages[responseMessageId] = responseMessage;
|
||||
return;
|
||||
}
|
||||
|
||||
const searchQuery = queries[0];
|
||||
|
||||
responseMessage.statusHistory.push({
|
||||
done: false,
|
||||
action: 'web_search',
|
||||
description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery })
|
||||
});
|
||||
history.messages[responseMessageId] = responseMessage;
|
||||
|
||||
const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => {
|
||||
console.log(error);
|
||||
toast.error(error);
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
if (results) {
|
||||
responseMessage.statusHistory.push({
|
||||
done: true,
|
||||
action: 'web_search',
|
||||
description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }),
|
||||
query: searchQuery,
|
||||
urls: results.filenames
|
||||
});
|
||||
|
||||
if (responseMessage?.files ?? undefined === undefined) {
|
||||
responseMessage.files = [];
|
||||
}
|
||||
|
||||
responseMessage.files.push({
|
||||
collection_name: results.collection_name,
|
||||
name: searchQuery,
|
||||
type: 'web_search_results',
|
||||
urls: results.filenames
|
||||
});
|
||||
history.messages[responseMessageId] = responseMessage;
|
||||
} else {
|
||||
responseMessage.statusHistory.push({
|
||||
done: true,
|
||||
error: true,
|
||||
action: 'web_search',
|
||||
description: 'No search results found'
|
||||
});
|
||||
history.messages[responseMessageId] = responseMessage;
|
||||
}
|
||||
};
|
||||
|
||||
const initChatHandler = async () => {
|
||||
if (!$temporaryChatEnabled) {
|
||||
chat = await createNewChat(localStorage.token, {
|
||||
|
@ -535,7 +535,14 @@
|
||||
? 'shimmer'
|
||||
: ''} text-base line-clamp-1 text-wrap"
|
||||
>
|
||||
{status?.description}
|
||||
<!-- $i18n.t('Searched {{count}} sites') -->
|
||||
{#if status?.description.includes('{{count}}')}
|
||||
{$i18n.t(status?.description, {
|
||||
count: status?.urls.length
|
||||
})}
|
||||
{:else}
|
||||
{$i18n.t(status?.description)}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</WebSearchResults>
|
||||
@ -558,7 +565,14 @@
|
||||
? 'shimmer'
|
||||
: ''} text-gray-500 dark:text-gray-500 text-base line-clamp-1 text-wrap"
|
||||
>
|
||||
{status?.description}
|
||||
<!-- $i18n.t(`Searching "{{searchQuery}}"`) -->
|
||||
{#if status?.description.includes('{{searchQuery}}')}
|
||||
{$i18n.t(status?.description, {
|
||||
searchQuery: status?.query
|
||||
})}
|
||||
{:else}
|
||||
{$i18n.t(status?.description)}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
Loading…
Reference in New Issue
Block a user