diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index c88964a2..14c8e737 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -315,6 +315,9 @@ export const ChatImpl = memo( runAnimation(); + const existingRepositoryId = getMessagesRepositoryId(messages); + let updatedRepository = false; + const addResponseMessage = (msg: Message) => { if (gNumAborts != numAbortsAtStart) { return; @@ -337,6 +340,13 @@ export const ChatImpl = memo( } setMessages(newMessages); + + // Update the repository as soon as it has changed. + const responseRepositoryId = getMessagesRepositoryId(newMessages); + if (responseRepositoryId && existingRepositoryId != responseRepositoryId) { + simulationRepositoryUpdated(responseRepositoryId); + updatedRepository = true; + } }; const references: ChatReference[] = []; @@ -375,12 +385,7 @@ export const ChatImpl = memo( textareaRef.current?.blur(); - const existingRepositoryId = getMessagesRepositoryId(messages); - const responseRepositoryId = getMessagesRepositoryId(newMessages); - - if (responseRepositoryId && existingRepositoryId != responseRepositoryId) { - simulationRepositoryUpdated(responseRepositoryId); - + if (updatedRepository) { const lastMessage = newMessages[newMessages.length - 1]; setApproveChangesMessageId(lastMessage.id); } else { @@ -457,8 +462,18 @@ export const ChatImpl = memo( setApproveChangesMessageId(undefined); const message = messages.find((message) => message.id === messageId); + assert(message, 'Message not found'); + assert(message == messages[messages.length - 1], 'Message must be the last message'); - await onRewind(messageId); + // Erase all messages since the last user message. + let rewindMessageId = message.id; + for (let i = messages.length - 2; i >= 0; i--) { + if (messages[i].role == 'user') { + break; + } + rewindMessageId = messages[i].id; + } + await onRewind(rewindMessageId); let shareProjectSuccess = false; @@ -466,7 +481,6 @@ export const ChatImpl = memo( const feedbackData: any = { explanation: data.explanation, chatMessages: messages, - repositoryId: message?.repositoryId, loginKey: getNutLoginKey(), }; diff --git a/app/components/sidebar/SaveReproduction.tsx b/app/components/sidebar/SaveReproduction.tsx index 6f806068..c8053657 100644 --- a/app/components/sidebar/SaveReproduction.tsx +++ b/app/components/sidebar/SaveReproduction.tsx @@ -8,6 +8,7 @@ import { getLastUserSimulationData, getLastSimulationChatMessages, isSimulatingOrHasFinished, + getLastSimulationChatReferences, } from '~/lib/replay/SimulationPrompt'; ReactModal.setAppElement('#root'); @@ -66,13 +67,14 @@ export function SaveReproductionModal() { } const messages = getLastSimulationChatMessages(); + const references = getLastSimulationChatReferences(); if (!messages) { toast.error('No user prompt found'); return; } - const reproData = { simulationData, messages }; + const reproData = { simulationData, messages, references }; /** * TODO: Split `solution` into `reproData` and `evaluator`. diff --git a/app/lib/replay/ReplayProtocolClient.ts b/app/lib/replay/ReplayProtocolClient.ts index b3fb8402..68b73c4a 100644 --- a/app/lib/replay/ReplayProtocolClient.ts +++ b/app/lib/replay/ReplayProtocolClient.ts @@ -113,6 +113,11 @@ export class ProtocolClient { close() { this.socket.close(); + + for (const info of this.pendingCommands.values()) { + info.deferred.reject(new Error('Client destroyed')); + } + this.pendingCommands.clear(); } listenForMessage(method: string, callback: (params: any) => void) { diff --git a/app/lib/replay/SimulationPrompt.ts b/app/lib/replay/SimulationPrompt.ts index b95fe70b..f55d1d16 100644 --- a/app/lib/replay/SimulationPrompt.ts +++ b/app/lib/replay/SimulationPrompt.ts @@ -46,6 +46,10 @@ class ChatManager { // Simulation data for the page itself and any user interactions. pageData: SimulationData = []; + // State to ensure that the chat manager is not destroyed until all messages finish. + private pendingMessages = 0; + private mustDestroyAfterChatFinishes = false; + constructor() { this.client = new ProtocolClient(); this.chatIdPromise = (async () => { @@ -65,11 +69,19 @@ class ChatManager { return !!this.client; } - destroy() { + private destroy() { this.client?.close(); this.client = undefined; } + destroyAfterChatFinishes() { + if (this.pendingMessages == 0) { + this.destroy(); + } else { + this.mustDestroyAfterChatFinishes = true; + } + } + async setRepositoryId(repositoryId: string) { assert(this.client, 'Chat has been destroyed'); this.repositoryId = repositoryId; @@ -126,6 +138,8 @@ class ChatManager { async sendChatMessage(messages: Message[], references: ChatReference[], onResponsePart: ChatResponsePartCallback) { assert(this.client, 'Chat has been destroyed'); + this.pendingMessages++; + const responseId = `response-${generateRandomId()}`; const removeResponseListener = this.client.listenForMessage( @@ -147,7 +161,13 @@ class ChatManager { params: { chatId, responseId, messages, references }, }); + console.log('ChatMessageFinished', new Date().toISOString(), chatId); + removeResponseListener(); + + if (--this.pendingMessages == 0 && this.mustDestroyAfterChatFinishes) { + this.destroy(); + } } } @@ -155,8 +175,11 @@ class ChatManager { let gChatManager: ChatManager | undefined; function startChat(repositoryId: string | undefined, pageData: SimulationData) { + // Any existing chat manager won't be used anymore for new messages, but it will + // not close until its messages actually finish and any future repository updates + // occur. if (gChatManager) { - gChatManager.destroy(); + gChatManager.destroyAfterChatFinishes(); } gChatManager = new ChatManager(); @@ -206,12 +229,15 @@ export function simulationAddData(data: SimulationData) { gChatManager.addPageData(data); } -export function simulationFinishData() { - gChatManager?.finishSimulationData(); -} - let gLastUserSimulationData: SimulationData | undefined; +export function simulationFinishData() { + if (gChatManager) { + gChatManager.finishSimulationData(); + gLastUserSimulationData = [...gChatManager.pageData]; + } +} + export function getLastUserSimulationData(): SimulationData | undefined { return gLastUserSimulationData; } @@ -226,6 +252,12 @@ export function getLastSimulationChatMessages(): Message[] | undefined { return gLastSimulationChatMessages; } +let gLastSimulationChatReferences: ChatReference[] | undefined; + +export function getLastSimulationChatReferences(): ChatReference[] | undefined { + return gLastSimulationChatReferences; +} + export async function sendChatMessage( messages: Message[], references: ChatReference[], @@ -235,5 +267,8 @@ export async function sendChatMessage( gChatManager = new ChatManager(); } + gLastSimulationChatMessages = messages; + gLastSimulationChatReferences = references; + await gChatManager.sendChatMessage(messages, references, onResponsePart); }