refac: many model chat

This commit is contained in:
Timothy J. Baek 2024-08-16 18:54:30 +02:00
parent 4f47053e93
commit 28e3e6e8cb
2 changed files with 42 additions and 36 deletions

View File

@ -562,7 +562,7 @@
content: userPrompt, content: userPrompt,
files: _files.length > 0 ? _files : undefined, files: _files.length > 0 ? _files : undefined,
timestamp: Math.floor(Date.now() / 1000), // Unix epoch timestamp: Math.floor(Date.now() / 1000), // Unix epoch
models: selectedModels.filter((m, mIdx) => selectedModels.indexOf(m) === mIdx) models: selectedModels
}; };
// Add message to history and Set currentId to messageId // Add message to history and Set currentId to messageId
@ -582,7 +582,11 @@
return _responses; return _responses;
}; };
const sendPrompt = async (prompt, parentId, { modelId = null, newChat = false } = {}) => { const sendPrompt = async (
prompt,
parentId,
{ modelId = null, modelIdx = null, newChat = false } = {}
) => {
let _responses = []; let _responses = [];
// If modelId is provided, use it, else use selected model // If modelId is provided, use it, else use selected model
@ -594,7 +598,7 @@
// Create response messages for each selected model // Create response messages for each selected model
const responseMessageIds = {}; const responseMessageIds = {};
for (const modelId of selectedModelIds) { for (const [_modelIdx, modelId] of selectedModelIds.entries()) {
const model = $models.filter((m) => m.id === modelId).at(0); const model = $models.filter((m) => m.id === modelId).at(0);
if (model) { if (model) {
@ -607,6 +611,7 @@
content: '', content: '',
model: model.id, model: model.id,
modelName: model.name ?? model.id, modelName: model.name ?? model.id,
modelIdx: modelIdx ? modelIdx : _modelIdx,
userContext: null, userContext: null,
timestamp: Math.floor(Date.now() / 1000) // Unix epoch timestamp: Math.floor(Date.now() / 1000) // Unix epoch
}; };
@ -623,7 +628,7 @@
]; ];
} }
responseMessageIds[modelId] = responseMessageId; responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`] = responseMessageId;
} }
} }
await tick(); await tick();
@ -655,7 +660,7 @@
const _chatId = JSON.parse(JSON.stringify($chatId)); const _chatId = JSON.parse(JSON.stringify($chatId));
await Promise.all( await Promise.all(
selectedModelIds.map(async (modelId) => { selectedModelIds.map(async (modelId, _modelIdx) => {
console.log('modelId', modelId); console.log('modelId', modelId);
const model = $models.filter((m) => m.id === modelId).at(0); const model = $models.filter((m) => m.id === modelId).at(0);
@ -673,7 +678,8 @@
); );
} }
let responseMessageId = responseMessageIds[modelId]; let responseMessageId =
responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`];
let responseMessage = history.messages[responseMessageId]; let responseMessage = history.messages[responseMessageId];
let userContext = null; let userContext = null;
@ -1350,7 +1356,10 @@
} else { } else {
// If there are multiple models selected, use the model of the response message for regeneration // If there are multiple models selected, use the model of the response message for regeneration
// e.g. many model chat // e.g. many model chat
await sendPrompt(userPrompt, userMessage.id, { modelId: message.model }); await sendPrompt(userPrompt, userMessage.id, {
modelId: message.model,
modelIdx: message.modelIdx
});
} }
} }
}; };

View File

@ -26,24 +26,24 @@
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
let currentMessageId; let currentMessageId;
let groupedMessagesIdx = {};
let groupedMessages = {}; let groupedMessages = {};
let groupedMessagesIdx = {};
$: groupedMessages = parentMessage?.models.reduce((a, model) => { $: groupedMessages = parentMessage?.models.reduce((a, model, modelIdx) => {
// Find all messages that are children of the parent message and have the same model
const modelMessages = parentMessage?.childrenIds const modelMessages = parentMessage?.childrenIds
.map((id) => history.messages[id]) .map((id) => history.messages[id])
.filter((m) => m.model === model); .filter((m) => m.modelIdx === modelIdx);
return { return {
...a, ...a,
[model]: { messages: modelMessages } [modelIdx]: { messages: modelMessages }
}; };
}, {}); }, {});
const showPreviousMessage = (model) => { const showPreviousMessage = (modelIdx) => {
groupedMessagesIdx[model] = Math.max(0, groupedMessagesIdx[model] - 1); groupedMessagesIdx[modelIdx] = Math.max(0, groupedMessagesIdx[modelIdx] - 1);
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id; let messageId = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]].id;
console.log(messageId); console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds; let messageChildrenIds = history.messages[messageId].childrenIds;
@ -54,17 +54,16 @@
} }
history.currentId = messageId; history.currentId = messageId;
dispatch('change'); dispatch('change');
}; };
const showNextMessage = (model) => { const showNextMessage = (modelIdx) => {
groupedMessagesIdx[model] = Math.min( groupedMessagesIdx[modelIdx] = Math.min(
groupedMessages[model].messages.length - 1, groupedMessages[modelIdx].messages.length - 1,
groupedMessagesIdx[model] + 1 groupedMessagesIdx[modelIdx] + 1
); );
let messageId = groupedMessages[model].messages[groupedMessagesIdx[model]].id; let messageId = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]].id;
console.log(messageId); console.log(messageId);
let messageChildrenIds = history.messages[messageId].childrenIds; let messageChildrenIds = history.messages[messageId].childrenIds;
@ -75,7 +74,6 @@
} }
history.currentId = messageId; history.currentId = messageId;
dispatch('change'); dispatch('change');
}; };
@ -83,13 +81,12 @@
await tick(); await tick();
currentMessageId = messages[messageIdx].id; currentMessageId = messages[messageIdx].id;
for (const model of parentMessage?.models) { for (const [modelIdx, model] of parentMessage?.models.entries()) {
const idx = groupedMessages[model].messages.findIndex((m) => m.id === currentMessageId); const idx = groupedMessages[modelIdx].messages.findIndex((m) => m.id === currentMessageId);
if (idx !== -1) { if (idx !== -1) {
groupedMessagesIdx[model] = idx; groupedMessagesIdx[modelIdx] = idx;
} else { } else {
groupedMessagesIdx[model] = 0; groupedMessagesIdx[modelIdx] = 0;
} }
} }
}); });
@ -101,16 +98,16 @@
id="responses-container-{parentMessage.id}" id="responses-container-{parentMessage.id}"
> >
{#key currentMessageId} {#key currentMessageId}
{#each Object.keys(groupedMessages) as model} {#each Object.keys(groupedMessages) as modelIdx}
{#if groupedMessagesIdx[model] !== undefined && groupedMessages[model].messages.length > 0} {#if groupedMessagesIdx[modelIdx] !== undefined && groupedMessages[modelIdx].messages.length > 0}
<!-- svelte-ignore a11y-no-static-element-interactions --> <!-- svelte-ignore a11y-no-static-element-interactions -->
<!-- svelte-ignore a11y-click-events-have-key-events --> <!-- svelte-ignore a11y-click-events-have-key-events -->
{@const message = groupedMessages[model].messages[groupedMessagesIdx[model]]} {@const message = groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]]}
<div <div
class=" snap-center min-w-80 w-full max-w-full m-1 border {history.messages[ class=" snap-center min-w-80 w-full max-w-full m-1 border {history.messages[
currentMessageId currentMessageId
].model === model ].modelIdx === modelIdx
? 'border-gray-100 dark:border-gray-800 border-[1.5px]' ? 'border-gray-100 dark:border-gray-800 border-[1.5px]'
: 'border-gray-50 dark:border-gray-850 '} transition p-5 rounded-3xl" : 'border-gray-50 dark:border-gray-850 '} transition p-5 rounded-3xl"
on:click={() => { on:click={() => {
@ -131,13 +128,13 @@
> >
{#key history.currentId} {#key history.currentId}
<ResponseMessage <ResponseMessage
message={groupedMessages[model].messages[groupedMessagesIdx[model]]} message={groupedMessages[modelIdx].messages[groupedMessagesIdx[modelIdx]]}
siblings={groupedMessages[model].messages.map((m) => m.id)} siblings={groupedMessages[modelIdx].messages.map((m) => m.id)}
isLastMessage={true} isLastMessage={true}
{updateChatMessages} {updateChatMessages}
{confirmEditResponseMessage} {confirmEditResponseMessage}
showPreviousMessage={() => showPreviousMessage(model)} showPreviousMessage={() => showPreviousMessage(modelIdx)}
showNextMessage={() => showNextMessage(model)} showNextMessage={() => showNextMessage(modelIdx)}
{readOnly} {readOnly}
{rateMessage} {rateMessage}
{copyToClipboard} {copyToClipboard}
@ -145,7 +142,7 @@
regenerateResponse={async (message) => { regenerateResponse={async (message) => {
regenerateResponse(message); regenerateResponse(message);
await tick(); await tick();
groupedMessagesIdx[model] = groupedMessages[model].messages.length - 1; groupedMessagesIdx[modelIdx] = groupedMessages[modelIdx].messages.length - 1;
}} }}
on:save={async (e) => { on:save={async (e) => {
console.log('save', e); console.log('save', e);