From f01e43d798833813f2b5f78f4c3181fc502ab072 Mon Sep 17 00:00:00 2001
From: Daniel Woelfel
Date: Wed, 21 May 2025 13:10:09 -0700
Subject: [PATCH] Enforce max response segments

---
 app/components/chat/Chat.client.tsx         | 24 +++++++++++++++++-
 app/components/chat/ProgressCompilation.tsx |  2 ++
 app/routes/api.chat.ts                      | 28 +++++++++++++++------
 app/types/context.ts                        |  8 +++++-
 4 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx
index 1c6eca2f..fc206044 100644
--- a/app/components/chat/Chat.client.tsx
+++ b/app/components/chat/Chat.client.tsx
@@ -27,6 +27,7 @@ import { logStore } from '~/lib/stores/logs';
 import { streamingState } from '~/lib/stores/streaming';
 import { filesToArtifacts } from '~/utils/fileUtils';
 import { supabaseConnection } from '~/lib/stores/supabase';
+import type { DataStreamError } from '~/types/context';
 
 const toastAnimation = cssTransition({
   enter: 'animated fadeInRight',
@@ -148,6 +149,9 @@ export const ChatImpl = memo(
 
     const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
 
+    // Keep track of the errors we alerted on. useChat gets the same data twice even if they're removed with setData
+    const alertedErrorIds = useRef(new Set<string>());
+
     const {
       messages,
       isLoading,
@@ -191,7 +195,10 @@
       },
       onFinish: (message, response) => {
         const usage = response.usage;
-        setData(undefined);
+        setData(() => {
+          alertedErrorIds.current.clear();
+          return undefined;
+        });
 
         if (usage) {
           console.log('Token usage:', usage);
@@ -230,6 +237,21 @@
       }
     }, [model, provider, searchParams]);
 
+    useEffect(() => {
+      if (chatData) {
+        for (const data of chatData) {
+          if (data && typeof data === 'object' && 'type' in data && data.type === 'error') {
+            const error = data as DataStreamError;
+
+            if (!alertedErrorIds.current.has(error.id)) {
+              toast.error('There was an error processing your request: ' + error.message);
+              alertedErrorIds.current.add(error.id);
+            }
+          }
+        }
+      }
+    }, [chatData]);
+
     const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = usePromptEnhancer();
     const { parsedMessages, parseMessages } = useMessageParser();
 
diff --git a/app/components/chat/ProgressCompilation.tsx b/app/components/chat/ProgressCompilation.tsx
index 68ae3388..aca193db 100644
--- a/app/components/chat/ProgressCompilation.tsx
+++ b/app/components/chat/ProgressCompilation.tsx
@@ -100,6 +100,8 @@ const ProgressItem = ({ progress }: { progress: ProgressAnnotation }) => {
             <div className="i-svg-spinners:90-ring-with-bg"></div>
           ) : progress.status === 'complete' ? (
             <div className="i-ph:check"></div>
+          ) : progress.status === 'error' ? (
+            <div className="i-ph:x"></div>
           ) : null}
         </div>
         {/* {x.label} */}
diff --git a/app/routes/api.chat.ts b/app/routes/api.chat.ts
index 5917dfc4..6ed2887c 100644
--- a/app/routes/api.chat.ts
+++ b/app/routes/api.chat.ts
@@ -3,11 +3,10 @@ import { createDataStream, generateId } from 'ai';
 import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS, type FileMap } from '~/lib/.server/llm/constants';
 import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
 import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
-import SwitchableStream from '~/lib/.server/llm/switchable-stream';
 import type { IProviderSetting } from '~/types/model';
 import { createScopedLogger } from '~/utils/logger';
 import { getFilePaths, selectContext } from '~/lib/.server/llm/select-context';
-import type { ContextAnnotation, ProgressAnnotation } from '~/types/context';
+import type { ContextAnnotation, DataStreamError, ProgressAnnotation } from '~/types/context';
 import { WORK_DIR } from '~/utils/constants';
 import { createSummary } from '~/lib/.server/llm/create-summary';
 import { extractPropertiesFromMessage } from '~/lib/.server/llm/utils';
@@ -58,7 +57,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
     parseCookies(cookieHeader || '').providers || '{}',
   );
 
-  const stream = new SwitchableStream();
+  let responseSegments = 0;
 
   const cumulativeUsage = {
     completionTokens: 0,
@@ -217,15 +216,30 @@
           } satisfies ProgressAnnotation);
           await new Promise((resolve) => setTimeout(resolve, 0));
 
-          // stream.close();
           return;
         }
 
-        if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
-          throw Error('Cannot continue message: Maximum segments reached');
+        responseSegments++;
+
+        if (responseSegments >= MAX_RESPONSE_SEGMENTS) {
+          dataStream.writeData({
+            type: 'error',
+            id: generateId(),
+            message: 'Cannot continue message: Maximum segments reached.',
+          } satisfies DataStreamError);
+          dataStream.writeData({
+            type: 'progress',
+            label: 'summary',
+            status: 'error',
+            order: progressCounter++,
+            message: 'Error: maximum segments reached.',
+          } satisfies ProgressAnnotation);
+          await new Promise((resolve) => setTimeout(resolve, 0));
+
+          return;
         }
 
-        const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
+        const switchesLeft = MAX_RESPONSE_SEGMENTS - responseSegments;
 
         logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);
 
diff --git a/app/types/context.ts b/app/types/context.ts
index 75c21db7..7fd63a26 100644
--- a/app/types/context.ts
+++ b/app/types/context.ts
@@ -12,7 +12,13 @@ export type ContextAnnotation =
 export type ProgressAnnotation = {
   type: 'progress';
   label: string;
-  status: 'in-progress' | 'complete';
+  status: 'in-progress' | 'complete' | 'error';
   order: number;
   message: string;
 };
+
+export type DataStreamError = {
+  type: 'error';
+  id: string;
+  message: string;
+};
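
Note (not part of the patch): a minimal, self-contained sketch of how the new DataStreamError annotation is intended to round-trip, assuming useChat may replay the same `data` entries across renders, which is why Chat.client.tsx dedupes by id. The DataStreamError type mirrors app/types/context.ts; `JSONValue` is re-declared locally to keep the sketch standalone, and `notifyOnce`, `toastError`, and the sample payload are hypothetical stand-ins for the toast logic in the component.

// dedupe-errors.ts -- standalone TypeScript sketch
type DataStreamError = {
  type: 'error';
  id: string;
  message: string;
};

// Local stand-in for the `ai` SDK's stream data value type.
type JSONValue = null | string | number | boolean | { [key: string]: JSONValue } | JSONValue[];

// Errors already surfaced to the user, keyed by the server-generated id
// (generateId() in api.chat.ts), so a replayed data array alerts only once.
const alertedErrorIds = new Set<string>();

function notifyOnce(chatData: JSONValue[] | undefined, toastError: (msg: string) => void): void {
  for (const data of chatData ?? []) {
    if (data && typeof data === 'object' && !Array.isArray(data) && data.type === 'error') {
      const error = data as unknown as DataStreamError;

      if (!alertedErrorIds.has(error.id)) {
        toastError('There was an error processing your request: ' + error.message);
        alertedErrorIds.add(error.id);
      }
    }
  }
}

// A replayed payload with the same id produces a single alert.
const payload: JSONValue[] = [
  { type: 'error', id: 'seg-limit-1', message: 'Cannot continue message: Maximum segments reached.' },
];
notifyOnce(payload, console.error); // alerts once
notifyOnce(payload, console.error); // deduplicated; no second alert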