mirror of
https://github.com/open-webui/open-webui
synced 2025-01-18 00:30:51 +00:00
feat: prototype frontend web search integration
This commit is contained in:
parent
619c2f9c71
commit
2660a6e5b8
@ -93,6 +93,7 @@ from config import (
|
|||||||
CHUNK_OVERLAP,
|
CHUNK_OVERLAP,
|
||||||
RAG_TEMPLATE,
|
RAG_TEMPLATE,
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||||
|
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
@ -538,18 +539,23 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
|||||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
|
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
|
||||||
# Check if the URL is valid
|
# Check if the URL is valid
|
||||||
if not validate_url(url):
|
if not validate_url(url):
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
return WebBaseLoader(url, verify_ssl=verify_ssl)
|
return WebBaseLoader(
|
||||||
|
url,
|
||||||
|
verify_ssl=verify_ssl,
|
||||||
|
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_url(url: Union[str, Sequence[str]]):
|
def validate_url(url: Union[str, Sequence[str]]):
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
if isinstance(validators.url(url), validators.ValidationError):
|
if isinstance(validators.url(url), validators.ValidationError):
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
if not ENABLE_LOCAL_WEB_FETCH:
|
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
||||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
# Get IPv4 and IPv6 addresses
|
# Get IPv4 and IPv6 addresses
|
||||||
@ -593,7 +599,7 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
|
|||||||
)
|
)
|
||||||
urls = [result.link for result in web_results]
|
urls = [result.link for result in web_results]
|
||||||
loader = get_web_loader(urls)
|
loader = get_web_loader(urls)
|
||||||
data = loader.load()
|
data = loader.aload()
|
||||||
|
|
||||||
collection_name = form_data.collection_name
|
collection_name = form_data.collection_name
|
||||||
if collection_name == "":
|
if collection_name == "":
|
||||||
|
@ -3,7 +3,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
|
|||||||
"Accept-Encoding": "gzip",
|
"Accept-Encoding": "gzip",
|
||||||
"X-Subscription-Token": api_key,
|
"X-Subscription-Token": api_key,
|
||||||
}
|
}
|
||||||
params = {"q": query, "count": WEB_SEARCH_RESULT_COUNT}
|
params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
|
||||||
|
|
||||||
response = requests.get(url, headers=headers, params=params)
|
response = requests.get(url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
|
|||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||||
)
|
)
|
||||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||||
]
|
]
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -27,7 +27,7 @@ def search_google_pse(
|
|||||||
"cx": search_engine_id,
|
"cx": search_engine_id,
|
||||||
"q": query,
|
"q": query,
|
||||||
"key": api_key,
|
"key": api_key,
|
||||||
"num": WEB_SEARCH_RESULT_COUNT,
|
"num": RAG_WEB_SEARCH_RESULT_COUNT,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.request("GET", url, headers=headers, params=params)
|
response = requests.request("GET", url, headers=headers, params=params)
|
||||||
|
@ -3,7 +3,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -40,5 +40,5 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
|
|||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
||||||
)
|
)
|
||||||
for result in sorted_results[:WEB_SEARCH_RESULT_COUNT]
|
for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||||
]
|
]
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
|
|||||||
title=result.get("title"),
|
title=result.get("title"),
|
||||||
snippet=result.get("description"),
|
snippet=result.get("description"),
|
||||||
)
|
)
|
||||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||||
]
|
]
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
|
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -39,5 +39,5 @@ def search_serpstack(
|
|||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||||
)
|
)
|
||||||
for result in results[:WEB_SEARCH_RESULT_COUNT]
|
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
|
||||||
]
|
]
|
||||||
|
@ -549,7 +549,10 @@ BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
|
|||||||
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
|
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
|
||||||
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
|
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
|
||||||
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
|
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
|
||||||
WEB_SEARCH_RESULT_COUNT = int(os.getenv("WEB_SEARCH_RESULT_COUNT", "10"))
|
RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "10"))
|
||||||
|
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
|
||||||
|
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
|
||||||
|
)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# Transcribe
|
# Transcribe
|
||||||
|
@ -318,3 +318,119 @@ export const generateTitle = async (
|
|||||||
|
|
||||||
return res?.choices[0]?.message?.content ?? 'New Chat';
|
return res?.choices[0]?.message?.content ?? 'New Chat';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const generateSearchQuery = async (
|
||||||
|
token: string = '',
|
||||||
|
// template: string,
|
||||||
|
model: string,
|
||||||
|
prompt: string,
|
||||||
|
url: string = OPENAI_API_BASE_URL
|
||||||
|
): Promise<string | undefined> => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
// TODO: Allow users to specify the prompt
|
||||||
|
// template = promptTemplate(template, prompt);
|
||||||
|
|
||||||
|
// Get the current date in the format "January 20, 2024"
|
||||||
|
const currentDate = new Intl.DateTimeFormat('en-US', {
|
||||||
|
year: 'numeric',
|
||||||
|
month: 'long',
|
||||||
|
day: '2-digit'
|
||||||
|
}).format(new Date());
|
||||||
|
const yesterdayDate = new Intl.DateTimeFormat('en-US', {
|
||||||
|
year: 'numeric',
|
||||||
|
month: 'long',
|
||||||
|
day: '2-digit'
|
||||||
|
}).format(new Date());
|
||||||
|
|
||||||
|
// console.log(template);
|
||||||
|
|
||||||
|
const res = await fetch(`${url}/chat/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: model,
|
||||||
|
// Few shot prompting
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.`
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: `Previous Questions:
|
||||||
|
- Who is the president of France?
|
||||||
|
|
||||||
|
Current Question: What about Mexico?`
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'President of Mexico'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: `Previous questions:
|
||||||
|
- When is the next formula 1 grand prix?
|
||||||
|
|
||||||
|
Current Question: Where is it being hosted?`
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'location of next formula 1 grand prix'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'Current Question: What type of printhead does the Epson F2270 DTG printer use?'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'Epson F2270 DTG printer printhead'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'What were the news yesterday?'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: `news ${yesterdayDate}`
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'What is the current weather in Paris?'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: `weather in Paris ${currentDate}`
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: `Current Question: ${prompt}`
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stream: false,
|
||||||
|
// Restricting the max tokens to 30 to avoid long search queries
|
||||||
|
max_tokens: 30
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err);
|
||||||
|
if ('detail' in err) {
|
||||||
|
error = err.detail;
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? undefined;
|
||||||
|
};
|
||||||
|
@ -507,3 +507,44 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod
|
|||||||
|
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const runWebSearch = async (
|
||||||
|
token: string,
|
||||||
|
query: string,
|
||||||
|
collection_name?: string
|
||||||
|
): Promise<SearchDocument | undefined> => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
query,
|
||||||
|
collection_name
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err);
|
||||||
|
error = err.detail;
|
||||||
|
return undefined;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
export interface SearchDocument {
|
||||||
|
status: boolean;
|
||||||
|
collection_name: string;
|
||||||
|
filenames: string[];
|
||||||
|
}
|
||||||
|
@ -30,8 +30,8 @@
|
|||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import { queryCollection, queryDoc } from '$lib/apis/rag';
|
import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
|
||||||
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
|
||||||
|
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
import Messages from '$lib/components/chat/Messages.svelte';
|
import Messages from '$lib/components/chat/Messages.svelte';
|
||||||
@ -55,6 +55,8 @@
|
|||||||
let selectedModels = [''];
|
let selectedModels = [''];
|
||||||
let atSelectedModel = '';
|
let atSelectedModel = '';
|
||||||
|
|
||||||
|
let useWebSearch = false;
|
||||||
|
|
||||||
let selectedModelfile = null;
|
let selectedModelfile = null;
|
||||||
$: selectedModelfile =
|
$: selectedModelfile =
|
||||||
selectedModels.length === 1 &&
|
selectedModels.length === 1 &&
|
||||||
@ -275,6 +277,39 @@
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (useWebSearch) {
|
||||||
|
// TODO: Toasts are temporary indicators for web search
|
||||||
|
toast.info($i18n.t('Generating search query'));
|
||||||
|
const searchQuery = await generateChatSearchQuery(prompt);
|
||||||
|
if (searchQuery) {
|
||||||
|
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
|
||||||
|
const searchDocUuid = uuidv4();
|
||||||
|
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
|
||||||
|
if (searchDocument) {
|
||||||
|
const parentMessage = history.messages[parentId];
|
||||||
|
if (!parentMessage.files) {
|
||||||
|
parentMessage.files = [];
|
||||||
|
}
|
||||||
|
parentMessage.files.push({
|
||||||
|
collection_name: searchDocument.collection_name,
|
||||||
|
name: searchQuery,
|
||||||
|
type: 'doc',
|
||||||
|
upload_status: true,
|
||||||
|
error: ""
|
||||||
|
});
|
||||||
|
// Find message in messages and update it
|
||||||
|
const messageIndex = messages.findIndex((message) => message.id === parentId);
|
||||||
|
if (messageIndex !== -1) {
|
||||||
|
messages[messageIndex] = parentMessage;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
toast.warning($i18n.t('No search results found'));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
toast.warning($i18n.t('No search query generated'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (model?.external) {
|
if (model?.external) {
|
||||||
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
||||||
} else if (model) {
|
} else if (model) {
|
||||||
@ -807,6 +842,30 @@
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: Add support for adding all the user's messages as context, and not just the last message
|
||||||
|
const generateChatSearchQuery = async (userPrompt: string) => {
|
||||||
|
const model = $models.find((model) => model.id === selectedModels[0]);
|
||||||
|
|
||||||
|
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
|
||||||
|
const titleModelId =
|
||||||
|
model?.external ?? false
|
||||||
|
? $settings?.title?.modelExternal ?? selectedModels[0]
|
||||||
|
: $settings?.title?.model ?? selectedModels[0];
|
||||||
|
const titleModel = $models.find((model) => model.id === titleModelId);
|
||||||
|
|
||||||
|
console.log(titleModel);
|
||||||
|
return await generateSearchQuery(
|
||||||
|
localStorage.token,
|
||||||
|
titleModelId,
|
||||||
|
userPrompt,
|
||||||
|
titleModel?.external ?? false
|
||||||
|
? titleModel?.source?.toLowerCase() === 'litellm'
|
||||||
|
? `${LITELLM_API_BASE_URL}/v1`
|
||||||
|
: `${OPENAI_API_BASE_URL}`
|
||||||
|
: `${OLLAMA_API_BASE_URL}/v1`
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const setChatTitle = async (_chatId, _title) => {
|
const setChatTitle = async (_chatId, _title) => {
|
||||||
if (_chatId === $chatId) {
|
if (_chatId === $chatId) {
|
||||||
title = _title;
|
title = _title;
|
||||||
@ -906,6 +965,7 @@
|
|||||||
bind:prompt
|
bind:prompt
|
||||||
bind:autoScroll
|
bind:autoScroll
|
||||||
bind:selectedModel={atSelectedModel}
|
bind:selectedModel={atSelectedModel}
|
||||||
|
bind:useWebSearch
|
||||||
{messages}
|
{messages}
|
||||||
{submitPrompt}
|
{submitPrompt}
|
||||||
{stopResponse}
|
{stopResponse}
|
||||||
|
@ -30,7 +30,7 @@
|
|||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
|
||||||
|
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
import Messages from '$lib/components/chat/Messages.svelte';
|
import Messages from '$lib/components/chat/Messages.svelte';
|
||||||
@ -43,6 +43,7 @@
|
|||||||
WEBUI_BASE_URL
|
WEBUI_BASE_URL
|
||||||
} from '$lib/constants';
|
} from '$lib/constants';
|
||||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||||
|
import { runWebSearch } from '$lib/apis/rag';
|
||||||
|
|
||||||
const i18n = getContext('i18n');
|
const i18n = getContext('i18n');
|
||||||
|
|
||||||
@ -59,6 +60,8 @@
|
|||||||
let selectedModels = [''];
|
let selectedModels = [''];
|
||||||
let atSelectedModel = '';
|
let atSelectedModel = '';
|
||||||
|
|
||||||
|
let useWebSearch = false;
|
||||||
|
|
||||||
let selectedModelfile = null;
|
let selectedModelfile = null;
|
||||||
|
|
||||||
$: selectedModelfile =
|
$: selectedModelfile =
|
||||||
@ -287,6 +290,39 @@
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (useWebSearch) {
|
||||||
|
// TODO: Toasts are temporary indicators for web search
|
||||||
|
toast.info($i18n.t('Generating search query'));
|
||||||
|
const searchQuery = await generateChatSearchQuery(prompt);
|
||||||
|
if (searchQuery) {
|
||||||
|
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
|
||||||
|
const searchDocUuid = uuidv4();
|
||||||
|
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
|
||||||
|
if (searchDocument) {
|
||||||
|
const parentMessage = history.messages[parentId];
|
||||||
|
if (!parentMessage.files) {
|
||||||
|
parentMessage.files = [];
|
||||||
|
}
|
||||||
|
parentMessage.files.push({
|
||||||
|
collection_name: searchDocument.collection_name,
|
||||||
|
name: searchQuery,
|
||||||
|
type: 'doc',
|
||||||
|
upload_status: true,
|
||||||
|
error: ""
|
||||||
|
});
|
||||||
|
// Find message in messages and update it
|
||||||
|
const messageIndex = messages.findIndex((message) => message.id === parentId);
|
||||||
|
if (messageIndex !== -1) {
|
||||||
|
messages[messageIndex] = parentMessage;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
toast.warning($i18n.t('No search results found'));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
toast.warning($i18n.t('No search query generated'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (model?.external) {
|
if (model?.external) {
|
||||||
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
|
||||||
} else if (model) {
|
} else if (model) {
|
||||||
@ -819,6 +855,30 @@
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: Add support for adding all the user's messages as context, and not just the last message
|
||||||
|
const generateChatSearchQuery = async (userPrompt: string) => {
|
||||||
|
const model = $models.find((model) => model.id === selectedModels[0]);
|
||||||
|
|
||||||
|
// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
|
||||||
|
const titleModelId =
|
||||||
|
model?.external ?? false
|
||||||
|
? $settings?.title?.modelExternal ?? selectedModels[0]
|
||||||
|
: $settings?.title?.model ?? selectedModels[0];
|
||||||
|
const titleModel = $models.find((model) => model.id === titleModelId);
|
||||||
|
|
||||||
|
console.log(titleModel);
|
||||||
|
return await generateSearchQuery(
|
||||||
|
localStorage.token,
|
||||||
|
titleModelId,
|
||||||
|
userPrompt,
|
||||||
|
titleModel?.external ?? false
|
||||||
|
? titleModel?.source?.toLowerCase() === 'litellm'
|
||||||
|
? `${LITELLM_API_BASE_URL}/v1`
|
||||||
|
: `${OPENAI_API_BASE_URL}`
|
||||||
|
: `${OLLAMA_API_BASE_URL}/v1`
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const setChatTitle = async (_chatId, _title) => {
|
const setChatTitle = async (_chatId, _title) => {
|
||||||
if (_chatId === $chatId) {
|
if (_chatId === $chatId) {
|
||||||
title = _title;
|
title = _title;
|
||||||
@ -929,6 +989,7 @@
|
|||||||
bind:prompt
|
bind:prompt
|
||||||
bind:autoScroll
|
bind:autoScroll
|
||||||
bind:selectedModel={atSelectedModel}
|
bind:selectedModel={atSelectedModel}
|
||||||
|
bind:useWebSearch
|
||||||
suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
|
suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
|
||||||
{messages}
|
{messages}
|
||||||
{submitPrompt}
|
{submitPrompt}
|
||||||
|
Loading…
Reference in New Issue
Block a user