Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search

This commit is contained in:
Jun Siang Cheah 2024-05-20 19:53:23 +01:00
commit 224a578e6b
8 changed files with 127 additions and 97 deletions

View File

@ -28,6 +28,7 @@ from langchain_community.document_loaders import (
UnstructuredXMLLoader,
UnstructuredRSTLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
YoutubeLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
@ -823,6 +824,11 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_content_type in [
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):

View File

@ -35,6 +35,7 @@ chromadb==0.4.24
sentence-transformers==2.7.0
pypdf==4.2.0
docx2txt==0.8
python-pptx==0.6.23
unstructured==0.11.8
Markdown==3.6
pypandoc==1.13

View File

@ -213,7 +213,7 @@ __builtins__.input = input`);
<div class="p-1">{@html lang}</div>
<div class="flex items-center">
{#if ['', 'python'].includes(lang) && (lang === 'python' || checkPythonCode(code))}
{#if lang === 'python' || (lang === '' && checkPythonCode(code))}
{#if executing}
<div class="copy-code-button bg-none border-none p-1 cursor-not-allowed">Running</div>
{:else}

View File

@ -41,6 +41,44 @@
};
}, {});
const showPreviousMessage = (model) => {
groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1);
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds;
while (messageChildrenIds.length !== 0) {
messageId = messageChildrenIds.at(-1);
messageChildrenIds = history.messages[messageId].childrenIds;
}
history.currentId = messageId;
dispatch('change');
};
const showNextMessage = (model) => {
groupedMessagesIdx[model] = Math.min(
groupedMessages[model].messages.length - 1,
groupedMessagesIdx[model] + 1
);
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds;
while (messageChildrenIds.length !== 0) {
messageId = messageChildrenIds.at(-1);
messageChildrenIds = history.messages[messageId].childrenIds;
}
history.currentId = messageId;
dispatch('change');
};
onMount(async () => {
await tick();
currentMessageId = messages[messageIdx].id;
@ -97,42 +135,8 @@
isLastMessage={true}
{updateChatMessages}
{confirmEditResponseMessage}
showPreviousMessage={() => {
groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1);
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds;
while (messageChildrenIds.length !== 0) {
messageId = messageChildrenIds.at(-1);
messageChildrenIds = history.messages[messageId].childrenIds;
}
history.currentId = messageId;
dispatch('change');
}}
showNextMessage={() => {
groupedMessagesIdx[model] = Math.min(
groupedMessages[model].messages.length - 1,
groupedMessagesIdx[model] + 1
);
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id;
console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds;
while (messageChildrenIds.length !== 0) {
messageId = messageChildrenIds.at(-1);
messageChildrenIds = history.messages[messageId].childrenIds;
}
history.currentId = messageId;
dispatch('change');
}}
showPreviousMessage={() => showPreviousMessage(model)}
showNextMessage={() => showNextMessage(model)}
{rateMessage}
{copyToClipboard}
{continueGeneration}

View File

@ -10,7 +10,8 @@
crossorigin="anonymous"
src={src.startsWith(WEBUI_BASE_URL) ||
src.startsWith('https://www.gravatar.com/avatar/') ||
src.startsWith('data:')
src.startsWith('data:') ||
src.startsWith('/')
? src
: `/user.png`}
class=" w-8 object-cover rounded-full"

View File

@ -86,7 +86,9 @@ export const SUPPORTED_FILE_EXTENSIONS = [
'csv',
'txt',
'xls',
'xlsx'
'xlsx',
'pptx',
'ppt'
];
// Source: https://kit.svelte.dev/docs/modules#$env-static-public

View File

@ -261,28 +261,6 @@
const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
if (res.documents[0].length > 0) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
}
console.log(userContext);
}
}
await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => {
@ -299,7 +277,7 @@
role: 'assistant',
content: '',
model: model.id,
userContext: userContext,
userContext: null,
timestamp: Math.floor(Date.now() / 1000) // Unix epoch
};
@ -315,6 +293,34 @@
];
}
await tick();
let userContext = null;
if ($settings?.memory ?? false) {
if (userContext === null) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
if (res.documents[0].length > 0) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000)
.toISOString()
.split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
}
console.log(userContext);
}
}
}
responseMessage.userContext = userContext;
if (useWebSearch) {
await runWebSearchForPrompt(model.id, parentId, responseMessageId);
}
@ -383,10 +389,11 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
content: `${$settings?.system ?? ''}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''
}`
}
: undefined,
...messages
@ -642,10 +649,11 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
content: `${$settings?.system ?? ''}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''
}`
}
: undefined,
...messages

View File

@ -268,28 +268,6 @@
const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
if (res.documents[0].length > 0) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
}
console.log(userContext);
}
}
await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => {
@ -306,7 +284,7 @@
role: 'assistant',
content: '',
model: model.id,
userContext: userContext,
userContext: null,
timestamp: Math.floor(Date.now() / 1000) // Unix epoch
};
@ -322,6 +300,34 @@
];
}
await tick();
let userContext = null;
if ($settings?.memory ?? false) {
if (userContext === null) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
if (res.documents[0].length > 0) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000)
.toISOString()
.split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
}
console.log(userContext);
}
}
}
responseMessage.userContext = userContext;
if (useWebSearch) {
await runWebSearchForPrompt(model.id, parentId, responseMessageId);
}
@ -390,10 +396,11 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
content: `${$settings?.system ?? ''}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''
}`
}
: undefined,
...messages
@ -649,10 +656,11 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
content: `${$settings?.system ?? ''}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''
}`
}
: undefined,
...messages