feat: memory integration

This commit is contained in:
Timothy J. Baek 2024-05-19 08:40:46 -07:00
parent 2638ae6a93
commit febab58821
4 changed files with 66 additions and 10 deletions

View File

@ -71,7 +71,7 @@ class QueryMemoryForm(BaseModel):
content: str
@router.post("/query", response_model=Optional[MemoryModel])
@router.post("/query")
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):

View File

@ -26,8 +26,8 @@
if (res) {
console.log(res);
toast.success('Memory added successfully');
content = '';
show = false;
dispatch('save');
}

View File

@ -41,6 +41,7 @@
import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
import { WEBUI_BASE_URL } from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
const i18n = getContext('i18n');
@ -254,6 +255,26 @@
const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
console.log(userContext);
}
}
await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => {
@ -270,6 +291,7 @@
role: 'assistant',
content: '',
model: model.id,
userContext: userContext,
timestamp: Math.floor(Date.now() / 1000) // Unix epoch
};
@ -311,10 +333,13 @@
scrollToBottom();
const messagesBody = [
$settings.system
$settings.system || responseMessage?.userContext
? {
role: 'system',
content: $settings.system
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
}
: undefined,
...messages
@ -567,10 +592,13 @@
model: model.id,
stream: true,
messages: [
$settings.system
$settings.system || responseMessage?.userContext
? {
role: 'system',
content: $settings.system
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
}
: undefined,
...messages

View File

@ -43,6 +43,7 @@
WEBUI_BASE_URL
} from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
const i18n = getContext('i18n');
@ -260,6 +261,26 @@
const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
console.log(userContext);
}
}
await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => {
@ -317,10 +338,13 @@
scrollToBottom();
const messagesBody = [
$settings.system
$settings.system || responseMessage?.userContext
? {
role: 'system',
content: $settings.system
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
}
: undefined,
...messages
@ -573,10 +597,13 @@
model: model.id,
stream: true,
messages: [
$settings.system
$settings.system || responseMessage?.userContext
? {
role: 'system',
content: $settings.system
content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
}
: undefined,
...messages
@ -705,6 +732,7 @@
} catch (error) {
await handleOpenAIError(error, null, model, responseMessage);
}
messages = messages;
stopResponseFlag = false;
await tick();