diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index a57a126c8..cb9d35ad8 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -93,6 +93,7 @@ from config import ( CHUNK_OVERLAP, RAG_TEMPLATE, ENABLE_RAG_LOCAL_WEB_FETCH, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) from constants import ERROR_MESSAGES @@ -538,18 +539,23 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): detail=ERROR_MESSAGES.DEFAULT(e), ) + def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): # Check if the URL is valid if not validate_url(url): raise ValueError(ERROR_MESSAGES.INVALID_URL) - return WebBaseLoader(url, verify_ssl=verify_ssl) + return WebBaseLoader( + url, + verify_ssl=verify_ssl, + requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + ) def validate_url(url: Union[str, Sequence[str]]): if isinstance(url, str): if isinstance(validators.url(url), validators.ValidationError): raise ValueError(ERROR_MESSAGES.INVALID_URL) - if not ENABLE_LOCAL_WEB_FETCH: + if not ENABLE_RAG_LOCAL_WEB_FETCH: # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses parsed_url = urllib.parse.urlparse(url) # Get IPv4 and IPv6 addresses @@ -593,7 +599,7 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)): ) urls = [result.link for result in web_results] loader = get_web_loader(urls) - data = loader.load() + data = loader.aload() collection_name = form_data.collection_name if collection_name == "": diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py index 91efaf396..50e364ca3 100644 --- a/backend/apps/rag/search/brave.py +++ b/backend/apps/rag/search/brave.py @@ -3,7 +3,7 @@ import logging import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT +from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -22,7 
+22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]: "Accept-Encoding": "gzip", "X-Subscription-Token": api_key, } - params = {"q": query, "count": WEB_SEARCH_RESULT_COUNT} + params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT} response = requests.get(url, headers=headers, params=params) response.raise_for_status() @@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]: SearchResult( link=result["url"], title=result.get("title"), snippet=result.get("snippet") ) - for result in results[:WEB_SEARCH_RESULT_COUNT] + for result in results[:RAG_WEB_SEARCH_RESULT_COUNT] ] diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py index 7b4a757a3..9cc22402d 100644 --- a/backend/apps/rag/search/google_pse.py +++ b/backend/apps/rag/search/google_pse.py @@ -4,7 +4,7 @@ import logging import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT +from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -27,7 +27,7 @@ def search_google_pse( "cx": search_engine_id, "q": query, "key": api_key, - "num": WEB_SEARCH_RESULT_COUNT, + "num": RAG_WEB_SEARCH_RESULT_COUNT, } response = requests.request("GET", url, headers=headers, params=params) diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py index dd2afd5f7..9848439ad 100644 --- a/backend/apps/rag/search/searxng.py +++ b/backend/apps/rag/search/searxng.py @@ -3,7 +3,7 @@ import logging import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT +from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -40,5 +40,5 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]: SearchResult( link=result["url"], 
title=result.get("title"), snippet=result.get("content") ) - for result in sorted_results[:WEB_SEARCH_RESULT_COUNT] + for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT] ] diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py index 8244ae0b6..49146b304 100644 --- a/backend/apps/rag/search/serper.py +++ b/backend/apps/rag/search/serper.py @@ -4,7 +4,7 @@ import logging import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT +from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]: title=result.get("title"), snippet=result.get("description"), ) - for result in results[:WEB_SEARCH_RESULT_COUNT] + for result in results[:RAG_WEB_SEARCH_RESULT_COUNT] ] diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py index 5cbf601ec..68222aa2b 100644 --- a/backend/apps/rag/search/serpstack.py +++ b/backend/apps/rag/search/serpstack.py @@ -4,7 +4,7 @@ import logging import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT +from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -39,5 +39,5 @@ def search_serpstack( SearchResult( link=result["url"], title=result.get("title"), snippet=result.get("snippet") ) - for result in results[:WEB_SEARCH_RESULT_COUNT] + for result in results[:RAG_WEB_SEARCH_RESULT_COUNT] ] diff --git a/backend/config.py b/backend/config.py index 513fab482..67dad6ae4 100644 --- a/backend/config.py +++ b/backend/config.py @@ -549,7 +549,10 @@ BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "") SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "") SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true" 
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "") -WEB_SEARCH_RESULT_COUNT = int(os.getenv("WEB_SEARCH_RESULT_COUNT", "10")) +RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "10")) +RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int( + os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10") +) #################################### # Transcribe diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 944e2d40d..f054afceb 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -318,3 +318,119 @@ export const generateTitle = async ( return res?.choices[0]?.message?.content ?? 'New Chat'; }; + +export const generateSearchQuery = async ( + token: string = '', + // template: string, + model: string, + prompt: string, + url: string = OPENAI_API_BASE_URL +): Promise<string | undefined> => { + let error = null; + + // TODO: Allow users to specify the prompt + // template = promptTemplate(template, prompt); + + // Get the current date in the format "January 20, 2024" + const currentDate = new Intl.DateTimeFormat('en-US', { + year: 'numeric', + month: 'long', + day: '2-digit' + }).format(new Date()); + const yesterdayDate = new Intl.DateTimeFormat('en-US', { + year: 'numeric', + month: 'long', + day: '2-digit' + }).format(new Date(Date.now() - 24 * 60 * 60 * 1000)); + + // console.log(template); + + const res = await fetch(`${url}/chat/completions`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + // Few shot prompting + messages: [ + { + role: 'assistant', + content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.` + }, + { + role: 'user', + content: `Previous Questions: +- Who is the president of France?
+ +Current Question: What about Mexico?` + }, + { + role: 'assistant', + content: 'President of Mexico' + }, + { + role: 'user', + content: `Previous questions: +- When is the next formula 1 grand prix? + +Current Question: Where is it being hosted?` + }, + { + role: 'assistant', + content: 'location of next formula 1 grand prix' + }, + { + role: 'user', + content: 'Current Question: What type of printhead does the Epson F2270 DTG printer use?' + }, + { + role: 'assistant', + content: 'Epson F2270 DTG printer printhead' + }, + { + role: 'user', + content: 'What were the news yesterday?' + }, + { + role: 'assistant', + content: `news ${yesterdayDate}` + }, + { + role: 'user', + content: 'What is the current weather in Paris?' + }, + { + role: 'assistant', + content: `weather in Paris ${currentDate}` + }, + { + role: 'user', + content: `Current Question: ${prompt}` + } + ], + stream: false, + // Restricting the max tokens to 30 to avoid long search queries + max_tokens: 30 + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return undefined; + }); + + if (error) { + throw error; + } + + return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 
undefined; +}; diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ccf166dab..727c64521 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -507,3 +507,44 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod return res; }; + +export const runWebSearch = async ( + token: string, + query: string, + collection_name?: string +): Promise<SearchDocument | undefined> => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/websearch`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + query, + collection_name + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return undefined; + }); + + if (error) { + throw error; + } + + return res; +}; + +export interface SearchDocument { + status: boolean; + collection_name: string; + filenames: string[]; +} diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index b36c6b3e2..fe413c6ed 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -30,8 +30,8 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryCollection, queryDoc } from '$lib/apis/rag'; - import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai'; + import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag'; + import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; import Messages from '$lib/components/chat/Messages.svelte'; @@ -55,6 +55,8 @@ let selectedModels = ['']; let atSelectedModel = ''; + let useWebSearch = false; + let selectedModelfile = null; $: selectedModelfile = selectedModels.length === 1 && @@ -275,6 +277,39 @@ ]; } + if (useWebSearch) { + // TODO: Toasts are temporary indicators for web
search + toast.info($i18n.t('Generating search query')); + const searchQuery = await generateChatSearchQuery(prompt); + if (searchQuery) { + toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery })); + const searchDocUuid = uuidv4(); + const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid); + if (searchDocument) { + const parentMessage = history.messages[parentId]; + if (!parentMessage.files) { + parentMessage.files = []; + } + parentMessage.files.push({ + collection_name: searchDocument.collection_name, + name: searchQuery, + type: 'doc', + upload_status: true, + error: "" + }); + // Find message in messages and update it + const messageIndex = messages.findIndex((message) => message.id === parentId); + if (messageIndex !== -1) { + messages[messageIndex] = parentMessage; + } + } else { + toast.warning($i18n.t('No search results found')); + } + } else { + toast.warning($i18n.t('No search query generated')); + } + } + if (model?.external) { await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); } else if (model) { @@ -807,6 +842,30 @@ } }; + // TODO: Add support for adding all the user's messages as context, and not just the last message + const generateChatSearchQuery = async (userPrompt: string) => { + const model = $models.find((model) => model.id === selectedModels[0]); + + // TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation) + const titleModelId = + model?.external ?? false + ? $settings?.title?.modelExternal ?? selectedModels[0] + : $settings?.title?.model ?? selectedModels[0]; + const titleModel = $models.find((model) => model.id === titleModelId); + + console.log(titleModel); + return await generateSearchQuery( + localStorage.token, + titleModelId, + userPrompt, + titleModel?.external ?? false + ? titleModel?.source?.toLowerCase() === 'litellm' + ? 
`${LITELLM_API_BASE_URL}/v1` + : `${OPENAI_API_BASE_URL}` + : `${OLLAMA_API_BASE_URL}/v1` + ); + }; + const setChatTitle = async (_chatId, _title) => { if (_chatId === $chatId) { title = _title; @@ -906,6 +965,7 @@ bind:prompt bind:autoScroll bind:selectedModel={atSelectedModel} + bind:useWebSearch {messages} {submitPrompt} {stopResponse} diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index ccf85317e..713ac0c27 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -30,7 +30,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai'; + import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; import Messages from '$lib/components/chat/Messages.svelte'; @@ -43,6 +43,7 @@ WEBUI_BASE_URL } from '$lib/constants'; import { createOpenAITextStream } from '$lib/apis/streaming'; + import { runWebSearch } from '$lib/apis/rag'; const i18n = getContext('i18n'); @@ -59,6 +60,8 @@ let selectedModels = ['']; let atSelectedModel = ''; + let useWebSearch = false; + let selectedModelfile = null; $: selectedModelfile = @@ -287,6 +290,39 @@ ]; } + if (useWebSearch) { + // TODO: Toasts are temporary indicators for web search + toast.info($i18n.t('Generating search query')); + const searchQuery = await generateChatSearchQuery(prompt); + if (searchQuery) { + toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery })); + const searchDocUuid = uuidv4(); + const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid); + if (searchDocument) { + const parentMessage = history.messages[parentId]; + if (!parentMessage.files) { + parentMessage.files = []; + } + parentMessage.files.push({ + collection_name: searchDocument.collection_name, + name: searchQuery, + type: 'doc', + 
upload_status: true, + error: "" + }); + // Find message in messages and update it + const messageIndex = messages.findIndex((message) => message.id === parentId); + if (messageIndex !== -1) { + messages[messageIndex] = parentMessage; + } + } else { + toast.warning($i18n.t('No search results found')); + } + } else { + toast.warning($i18n.t('No search query generated')); + } + } + if (model?.external) { await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); } else if (model) { @@ -819,6 +855,30 @@ } }; + // TODO: Add support for adding all the user's messages as context, and not just the last message + const generateChatSearchQuery = async (userPrompt: string) => { + const model = $models.find((model) => model.id === selectedModels[0]); + + // TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation) + const titleModelId = + model?.external ?? false + ? $settings?.title?.modelExternal ?? selectedModels[0] + : $settings?.title?.model ?? selectedModels[0]; + const titleModel = $models.find((model) => model.id === titleModelId); + + console.log(titleModel); + return await generateSearchQuery( + localStorage.token, + titleModelId, + userPrompt, + titleModel?.external ?? false + ? titleModel?.source?.toLowerCase() === 'litellm' + ? `${LITELLM_API_BASE_URL}/v1` + : `${OPENAI_API_BASE_URL}` + : `${OLLAMA_API_BASE_URL}/v1` + ); + }; + const setChatTitle = async (_chatId, _title) => { if (_chatId === $chatId) { title = _title; @@ -929,6 +989,7 @@ bind:prompt bind:autoScroll bind:selectedModel={atSelectedModel} + bind:useWebSearch suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions} {messages} {submitPrompt}