feat: prototype frontend web search integration

This commit is contained in:
Jun Siang Cheah 2024-05-11 23:12:52 +08:00
parent 619c2f9c71
commit 2660a6e5b8
11 changed files with 305 additions and 18 deletions

View File

@ -93,6 +93,7 @@ from config import (
CHUNK_OVERLAP, CHUNK_OVERLAP,
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -538,18 +539,23 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
) )
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid # Check if the URL is valid
if not validate_url(url): if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url, verify_ssl=verify_ssl) return WebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
def validate_url(url: Union[str, Sequence[str]]): def validate_url(url: Union[str, Sequence[str]]):
if isinstance(url, str): if isinstance(url, str):
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH: if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses # Get IPv4 and IPv6 addresses
@ -593,7 +599,7 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
) )
urls = [result.link for result in web_results] urls = [result.link for result in web_results]
loader = get_web_loader(urls) loader = get_web_loader(urls)
data = loader.load() data = loader.aload()
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == "": if collection_name == "":

View File

@ -3,7 +3,7 @@ import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
"Accept-Encoding": "gzip", "Accept-Encoding": "gzip",
"X-Subscription-Token": api_key, "X-Subscription-Token": api_key,
} }
params = {"q": query, "count": WEB_SEARCH_RESULT_COUNT} params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
response = requests.get(url, headers=headers, params=params) response = requests.get(url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
) )
for result in results[:WEB_SEARCH_RESULT_COUNT] for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
] ]

View File

@ -4,7 +4,7 @@ import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -27,7 +27,7 @@ def search_google_pse(
"cx": search_engine_id, "cx": search_engine_id,
"q": query, "q": query,
"key": api_key, "key": api_key,
"num": WEB_SEARCH_RESULT_COUNT, "num": RAG_WEB_SEARCH_RESULT_COUNT,
} }
response = requests.request("GET", url, headers=headers, params=params) response = requests.request("GET", url, headers=headers, params=params)

View File

@ -3,7 +3,7 @@ import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -40,5 +40,5 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content") link=result["url"], title=result.get("title"), snippet=result.get("content")
) )
for result in sorted_results[:WEB_SEARCH_RESULT_COUNT] for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
] ]

View File

@ -4,7 +4,7 @@ import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
title=result.get("title"), title=result.get("title"),
snippet=result.get("description"), snippet=result.get("description"),
) )
for result in results[:WEB_SEARCH_RESULT_COUNT] for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
] ]

View File

@ -4,7 +4,7 @@ import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -39,5 +39,5 @@ def search_serpstack(
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
) )
for result in results[:WEB_SEARCH_RESULT_COUNT] for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
] ]

View File

@ -549,7 +549,10 @@ BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "") SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true" SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "") SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
WEB_SEARCH_RESULT_COUNT = int(os.getenv("WEB_SEARCH_RESULT_COUNT", "10")) RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "10"))
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
)
#################################### ####################################
# Transcribe # Transcribe

View File

@ -318,3 +318,119 @@ export const generateTitle = async (
return res?.choices[0]?.message?.content ?? 'New Chat'; return res?.choices[0]?.message?.content ?? 'New Chat';
}; };
export const generateSearchQuery = async (
token: string = '',
// template: string,
model: string,
prompt: string,
url: string = OPENAI_API_BASE_URL
): Promise<string | undefined> => {
let error = null;
// TODO: Allow users to specify the prompt
// template = promptTemplate(template, prompt);
// Get the current date in the format "January 20, 2024"
const currentDate = new Intl.DateTimeFormat('en-US', {
year: 'numeric',
month: 'long',
day: '2-digit'
}).format(new Date());
const yesterdayDate = new Intl.DateTimeFormat('en-US', {
year: 'numeric',
month: 'long',
day: '2-digit'
}).format(new Date());
// console.log(template);
const res = await fetch(`${url}/chat/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
// Few shot prompting
messages: [
{
role: 'assistant',
content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.`
},
{
role: 'user',
content: `Previous Questions:
- Who is the president of France?
Current Question: What about Mexico?`
},
{
role: 'assistant',
content: 'President of Mexico'
},
{
role: 'user',
content: `Previous questions:
- When is the next formula 1 grand prix?
Current Question: Where is it being hosted?`
},
{
role: 'assistant',
content: 'location of next formula 1 grand prix'
},
{
role: 'user',
content: 'Current Question: What type of printhead does the Epson F2270 DTG printer use?'
},
{
role: 'assistant',
content: 'Epson F2270 DTG printer printhead'
},
{
role: 'user',
content: 'What were the news yesterday?'
},
{
role: 'assistant',
content: `news ${yesterdayDate}`
},
{
role: 'user',
content: 'What is the current weather in Paris?'
},
{
role: 'assistant',
content: `weather in Paris ${currentDate}`
},
{
role: 'user',
content: `Current Question: ${prompt}`
}
],
stream: false,
// Restricting the max tokens to 30 to avoid long search queries
max_tokens: 30
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return undefined;
});
if (error) {
throw error;
}
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? undefined;
};

View File

@ -507,3 +507,44 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod
return res; return res;
}; };
export const runWebSearch = async (
token: string,
query: string,
collection_name?: string
): Promise<SearchDocument | undefined> => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
query,
collection_name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return undefined;
});
if (error) {
throw error;
}
return res;
};
export interface SearchDocument {
status: boolean;
collection_name: string;
filenames: string[];
}

View File

@ -30,8 +30,8 @@
getTagsById, getTagsById,
updateChatById updateChatById
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { queryCollection, queryDoc } from '$lib/apis/rag'; import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai'; import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte'; import Messages from '$lib/components/chat/Messages.svelte';
@ -55,6 +55,8 @@
let selectedModels = ['']; let selectedModels = [''];
let atSelectedModel = ''; let atSelectedModel = '';
let useWebSearch = false;
let selectedModelfile = null; let selectedModelfile = null;
$: selectedModelfile = $: selectedModelfile =
selectedModels.length === 1 && selectedModels.length === 1 &&
@ -275,6 +277,39 @@
]; ];
} }
if (useWebSearch) {
// TODO: Toasts are temporary indicators for web search
toast.info($i18n.t('Generating search query'));
const searchQuery = await generateChatSearchQuery(prompt);
if (searchQuery) {
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
const searchDocUuid = uuidv4();
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
if (searchDocument) {
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ""
});
// Find message in messages and update it
const messageIndex = messages.findIndex((message) => message.id === parentId);
if (messageIndex !== -1) {
messages[messageIndex] = parentMessage;
}
} else {
toast.warning($i18n.t('No search results found'));
}
} else {
toast.warning($i18n.t('No search query generated'));
}
}
if (model?.external) { if (model?.external) {
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
} else if (model) { } else if (model) {
@ -807,6 +842,30 @@
} }
}; };
// TODO: Add support for adding all the user's messages as context, and not just the last message
const generateChatSearchQuery = async (userPrompt: string) => {
const model = $models.find((model) => model.id === selectedModels[0]);
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
const titleModelId =
model?.external ?? false
? $settings?.title?.modelExternal ?? selectedModels[0]
: $settings?.title?.model ?? selectedModels[0];
const titleModel = $models.find((model) => model.id === titleModelId);
console.log(titleModel);
return await generateSearchQuery(
localStorage.token,
titleModelId,
userPrompt,
titleModel?.external ?? false
? titleModel?.source?.toLowerCase() === 'litellm'
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
);
};
const setChatTitle = async (_chatId, _title) => { const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) { if (_chatId === $chatId) {
title = _title; title = _title;
@ -906,6 +965,7 @@
bind:prompt bind:prompt
bind:autoScroll bind:autoScroll
bind:selectedModel={atSelectedModel} bind:selectedModel={atSelectedModel}
bind:useWebSearch
{messages} {messages}
{submitPrompt} {submitPrompt}
{stopResponse} {stopResponse}

View File

@ -30,7 +30,7 @@
getTagsById, getTagsById,
updateChatById updateChatById
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai'; import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte'; import Messages from '$lib/components/chat/Messages.svelte';
@ -43,6 +43,7 @@
WEBUI_BASE_URL WEBUI_BASE_URL
} from '$lib/constants'; } from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming'; import { createOpenAITextStream } from '$lib/apis/streaming';
import { runWebSearch } from '$lib/apis/rag';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -59,6 +60,8 @@
let selectedModels = ['']; let selectedModels = [''];
let atSelectedModel = ''; let atSelectedModel = '';
let useWebSearch = false;
let selectedModelfile = null; let selectedModelfile = null;
$: selectedModelfile = $: selectedModelfile =
@ -287,6 +290,39 @@
]; ];
} }
if (useWebSearch) {
// TODO: Toasts are temporary indicators for web search
toast.info($i18n.t('Generating search query'));
const searchQuery = await generateChatSearchQuery(prompt);
if (searchQuery) {
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
const searchDocUuid = uuidv4();
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
if (searchDocument) {
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ""
});
// Find message in messages and update it
const messageIndex = messages.findIndex((message) => message.id === parentId);
if (messageIndex !== -1) {
messages[messageIndex] = parentMessage;
}
} else {
toast.warning($i18n.t('No search results found'));
}
} else {
toast.warning($i18n.t('No search query generated'));
}
}
if (model?.external) { if (model?.external) {
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
} else if (model) { } else if (model) {
@ -819,6 +855,30 @@
} }
}; };
// TODO: Add support for adding all the user's messages as context, and not just the last message
const generateChatSearchQuery = async (userPrompt: string) => {
const model = $models.find((model) => model.id === selectedModels[0]);
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
const titleModelId =
model?.external ?? false
? $settings?.title?.modelExternal ?? selectedModels[0]
: $settings?.title?.model ?? selectedModels[0];
const titleModel = $models.find((model) => model.id === titleModelId);
console.log(titleModel);
return await generateSearchQuery(
localStorage.token,
titleModelId,
userPrompt,
titleModel?.external ?? false
? titleModel?.source?.toLowerCase() === 'litellm'
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
);
};
const setChatTitle = async (_chatId, _title) => { const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) { if (_chatId === $chatId) {
title = _title; title = _title;
@ -929,6 +989,7 @@
bind:prompt bind:prompt
bind:autoScroll bind:autoScroll
bind:selectedModel={atSelectedModel} bind:selectedModel={atSelectedModel}
bind:useWebSearch
suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions} suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
{messages} {messages}
{submitPrompt} {submitPrompt}