mirror of
https://github.com/open-webui/open-webui
synced 2025-01-18 00:30:51 +00:00
feat: prototype frontend web search integration
This commit is contained in:
parent
619c2f9c71
commit
2660a6e5b8
@ -93,6 +93,7 @@ from config import (
|
||||
CHUNK_OVERLAP,
|
||||
RAG_TEMPLATE,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
@ -538,18 +539,23 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
|
||||
# Check if the URL is valid
|
||||
if not validate_url(url):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
return WebBaseLoader(url, verify_ssl=verify_ssl)
|
||||
return WebBaseLoader(
|
||||
url,
|
||||
verify_ssl=verify_ssl,
|
||||
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
)
|
||||
|
||||
|
||||
def validate_url(url: Union[str, Sequence[str]]):
|
||||
if isinstance(url, str):
|
||||
if isinstance(validators.url(url), validators.ValidationError):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
if not ENABLE_LOCAL_WEB_FETCH:
|
||||
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
# Get IPv4 and IPv6 addresses
|
||||
@ -593,7 +599,7 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
|
||||
)
|
||||
urls = [result.link for result in web_results]
|
||||
loader = get_web_loader(urls)
|
||||
data = loader.load()
|
||||
data = loader.aload()
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == "":
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
||||
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": api_key,
|
||||
}
|
||||
params = {"q": query, "count": WEB_SEARCH_RESULT_COUNT}
|
||||
params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||
)
|
||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
||||
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||
]
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
||||
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -27,7 +27,7 @@ def search_google_pse(
|
||||
"cx": search_engine_id,
|
||||
"q": query,
|
||||
"key": api_key,
|
||||
"num": WEB_SEARCH_RESULT_COUNT,
|
||||
"num": RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
}
|
||||
|
||||
response = requests.request("GET", url, headers=headers, params=params)
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
||||
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -40,5 +40,5 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
||||
)
|
||||
for result in sorted_results[:WEB_SEARCH_RESULT_COUNT]
|
||||
for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||
]
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
||||
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
)
|
||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
||||
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||
]
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
||||
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -39,5 +39,5 @@ def search_serpstack(
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||
)
|
||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
||||
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||
]
|
||||
|
@ -549,7 +549,10 @@ BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
|
||||
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
|
||||
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
|
||||
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
|
||||
WEB_SEARCH_RESULT_COUNT = int(os.getenv("WEB_SEARCH_RESULT_COUNT", "10"))
|
||||
RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "10"))
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
|
||||
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
|
||||
)
|
||||
|
||||
####################################
|
||||
# Transcribe
|
||||
|
@ -318,3 +318,119 @@ export const generateTitle = async (
|
||||
|
||||
return res?.choices[0]?.message?.content ?? 'New Chat';
|
||||
};
|
||||
|
||||
export const generateSearchQuery = async (
|
||||
token: string = '',
|
||||
// template: string,
|
||||
model: string,
|
||||
prompt: string,
|
||||
url: string = OPENAI_API_BASE_URL
|
||||
): Promise<string | undefined> => {
|
||||
let error = null;
|
||||
|
||||
// TODO: Allow users to specify the prompt
|
||||
// template = promptTemplate(template, prompt);
|
||||
|
||||
// Get the current date in the format "January 20, 2024"
|
||||
const currentDate = new Intl.DateTimeFormat('en-US', {
|
||||
year: 'numeric',
|
||||
month: 'long',
|
||||
day: '2-digit'
|
||||
}).format(new Date());
|
||||
const yesterdayDate = new Intl.DateTimeFormat('en-US', {
|
||||
year: 'numeric',
|
||||
month: 'long',
|
||||
day: '2-digit'
|
||||
}).format(new Date());
|
||||
|
||||
// console.log(template);
|
||||
|
||||
const res = await fetch(`${url}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model,
|
||||
// Few shot prompting
|
||||
messages: [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.`
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: `Previous Questions:
|
||||
- Who is the president of France?
|
||||
|
||||
Current Question: What about Mexico?`
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'President of Mexico'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: `Previous questions:
|
||||
- When is the next formula 1 grand prix?
|
||||
|
||||
Current Question: Where is it being hosted?`
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'location of next formula 1 grand prix'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Current Question: What type of printhead does the Epson F2270 DTG printer use?'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Epson F2270 DTG printer printhead'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'What were the news yesterday?'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: `news ${yesterdayDate}`
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'What is the current weather in Paris?'
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: `weather in Paris ${currentDate}`
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: `Current Question: ${prompt}`
|
||||
}
|
||||
],
|
||||
stream: false,
|
||||
// Restricting the max tokens to 30 to avoid long search queries
|
||||
max_tokens: 30
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
if ('detail' in err) {
|
||||
error = err.detail;
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? undefined;
|
||||
};
|
||||
|
@ -507,3 +507,44 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const runWebSearch = async (
|
||||
token: string,
|
||||
query: string,
|
||||
collection_name?: string
|
||||
): Promise<SearchDocument | undefined> => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
query,
|
||||
collection_name
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return undefined;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export interface SearchDocument {
|
||||
status: boolean;
|
||||
collection_name: string;
|
||||
filenames: string[];
|
||||
}
|
||||
|
@ -30,8 +30,8 @@
|
||||
getTagsById,
|
||||
updateChatById
|
||||
} from '$lib/apis/chats';
|
||||
import { queryCollection, queryDoc } from '$lib/apis/rag';
|
||||
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
|
||||
import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
|
||||
import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
|
||||
|
||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||
import Messages from '$lib/components/chat/Messages.svelte';
|
||||
@ -55,6 +55,8 @@
|
||||
let selectedModels = [''];
|
||||
let atSelectedModel = '';
|
||||
|
||||
let useWebSearch = false;
|
||||
|
||||
let selectedModelfile = null;
|
||||
$: selectedModelfile =
|
||||
selectedModels.length === 1 &&
|
||||
@ -275,6 +277,39 @@
|
||||
];
|
||||
}
|
||||
|
||||
if (useWebSearch) {
|
||||
// TODO: Toasts are temporary indicators for web search
|
||||
toast.info($i18n.t('Generating search query'));
|
||||
const searchQuery = await generateChatSearchQuery(prompt);
|
||||
if (searchQuery) {
|
||||
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
|
||||
const searchDocUuid = uuidv4();
|
||||
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
|
||||
if (searchDocument) {
|
||||
const parentMessage = history.messages[parentId];
|
||||
if (!parentMessage.files) {
|
||||
parentMessage.files = [];
|
||||
}
|
||||
parentMessage.files.push({
|
||||
collection_name: searchDocument.collection_name,
|
||||
name: searchQuery,
|
||||
type: 'doc',
|
||||
upload_status: true,
|
||||
error: ""
|
||||
});
|
||||
// Find message in messages and update it
|
||||
const messageIndex = messages.findIndex((message) => message.id === parentId);
|
||||
if (messageIndex !== -1) {
|
||||
messages[messageIndex] = parentMessage;
|
||||
}
|
||||
} else {
|
||||
toast.warning($i18n.t('No search results found'));
|
||||
}
|
||||
} else {
|
||||
toast.warning($i18n.t('No search query generated'));
|
||||
}
|
||||
}
|
||||
|
||||
if (model?.external) {
|
||||
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
||||
} else if (model) {
|
||||
@ -807,6 +842,30 @@
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Add support for adding all the user's messages as context, and not just the last message
|
||||
const generateChatSearchQuery = async (userPrompt: string) => {
|
||||
const model = $models.find((model) => model.id === selectedModels[0]);
|
||||
|
||||
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
|
||||
const titleModelId =
|
||||
model?.external ?? false
|
||||
? $settings?.title?.modelExternal ?? selectedModels[0]
|
||||
: $settings?.title?.model ?? selectedModels[0];
|
||||
const titleModel = $models.find((model) => model.id === titleModelId);
|
||||
|
||||
console.log(titleModel);
|
||||
return await generateSearchQuery(
|
||||
localStorage.token,
|
||||
titleModelId,
|
||||
userPrompt,
|
||||
titleModel?.external ?? false
|
||||
? titleModel?.source?.toLowerCase() === 'litellm'
|
||||
? `${LITELLM_API_BASE_URL}/v1`
|
||||
: `${OPENAI_API_BASE_URL}`
|
||||
: `${OLLAMA_API_BASE_URL}/v1`
|
||||
);
|
||||
};
|
||||
|
||||
const setChatTitle = async (_chatId, _title) => {
|
||||
if (_chatId === $chatId) {
|
||||
title = _title;
|
||||
@ -906,6 +965,7 @@
|
||||
bind:prompt
|
||||
bind:autoScroll
|
||||
bind:selectedModel={atSelectedModel}
|
||||
bind:useWebSearch
|
||||
{messages}
|
||||
{submitPrompt}
|
||||
{stopResponse}
|
||||
|
@ -30,7 +30,7 @@
|
||||
getTagsById,
|
||||
updateChatById
|
||||
} from '$lib/apis/chats';
|
||||
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
|
||||
import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
|
||||
|
||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||
import Messages from '$lib/components/chat/Messages.svelte';
|
||||
@ -43,6 +43,7 @@
|
||||
WEBUI_BASE_URL
|
||||
} from '$lib/constants';
|
||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||
import { runWebSearch } from '$lib/apis/rag';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
@ -59,6 +60,8 @@
|
||||
let selectedModels = [''];
|
||||
let atSelectedModel = '';
|
||||
|
||||
let useWebSearch = false;
|
||||
|
||||
let selectedModelfile = null;
|
||||
|
||||
$: selectedModelfile =
|
||||
@ -287,6 +290,39 @@
|
||||
];
|
||||
}
|
||||
|
||||
if (useWebSearch) {
|
||||
// TODO: Toasts are temporary indicators for web search
|
||||
toast.info($i18n.t('Generating search query'));
|
||||
const searchQuery = await generateChatSearchQuery(prompt);
|
||||
if (searchQuery) {
|
||||
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
|
||||
const searchDocUuid = uuidv4();
|
||||
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
|
||||
if (searchDocument) {
|
||||
const parentMessage = history.messages[parentId];
|
||||
if (!parentMessage.files) {
|
||||
parentMessage.files = [];
|
||||
}
|
||||
parentMessage.files.push({
|
||||
collection_name: searchDocument.collection_name,
|
||||
name: searchQuery,
|
||||
type: 'doc',
|
||||
upload_status: true,
|
||||
error: ""
|
||||
});
|
||||
// Find message in messages and update it
|
||||
const messageIndex = messages.findIndex((message) => message.id === parentId);
|
||||
if (messageIndex !== -1) {
|
||||
messages[messageIndex] = parentMessage;
|
||||
}
|
||||
} else {
|
||||
toast.warning($i18n.t('No search results found'));
|
||||
}
|
||||
} else {
|
||||
toast.warning($i18n.t('No search query generated'));
|
||||
}
|
||||
}
|
||||
|
||||
if (model?.external) {
|
||||
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
||||
} else if (model) {
|
||||
@ -819,6 +855,30 @@
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Add support for adding all the user's messages as context, and not just the last message
|
||||
const generateChatSearchQuery = async (userPrompt: string) => {
|
||||
const model = $models.find((model) => model.id === selectedModels[0]);
|
||||
|
||||
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
|
||||
const titleModelId =
|
||||
model?.external ?? false
|
||||
? $settings?.title?.modelExternal ?? selectedModels[0]
|
||||
: $settings?.title?.model ?? selectedModels[0];
|
||||
const titleModel = $models.find((model) => model.id === titleModelId);
|
||||
|
||||
console.log(titleModel);
|
||||
return await generateSearchQuery(
|
||||
localStorage.token,
|
||||
titleModelId,
|
||||
userPrompt,
|
||||
titleModel?.external ?? false
|
||||
? titleModel?.source?.toLowerCase() === 'litellm'
|
||||
? `${LITELLM_API_BASE_URL}/v1`
|
||||
: `${OPENAI_API_BASE_URL}`
|
||||
: `${OLLAMA_API_BASE_URL}/v1`
|
||||
);
|
||||
};
|
||||
|
||||
const setChatTitle = async (_chatId, _title) => {
|
||||
if (_chatId === $chatId) {
|
||||
title = _title;
|
||||
@ -929,6 +989,7 @@
|
||||
bind:prompt
|
||||
bind:autoScroll
|
||||
bind:selectedModel={atSelectedModel}
|
||||
bind:useWebSearch
|
||||
suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
|
||||
{messages}
|
||||
{submitPrompt}
|
||||
|
Loading…
Reference in New Issue
Block a user