Enforce max response segments

This commit is contained in:
Daniel Woelfel 2025-05-21 13:10:09 -07:00
parent f0aa58c922
commit f01e43d798
No known key found for this signature in database
GPG Key ID: 52E90E357B942D32
4 changed files with 53 additions and 9 deletions

View File

@@ -27,6 +27,7 @@ import { logStore } from '~/lib/stores/logs';
import { streamingState } from '~/lib/stores/streaming';
import { filesToArtifacts } from '~/utils/fileUtils';
import { supabaseConnection } from '~/lib/stores/supabase';
import type { DataStreamError } from '~/types/context';
const toastAnimation = cssTransition({
enter: 'animated fadeInRight',
@@ -148,6 +149,9 @@ export const ChatImpl = memo(
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
// Keep track of the errors we alerted on. useChat gets the same data twice even if they're removed with setData
const alertedErrorIds = useRef(new Set());
const {
messages,
isLoading,
@@ -191,7 +195,10 @@
},
onFinish: (message, response) => {
const usage = response.usage;
setData(undefined);
setData(() => {
alertedErrorIds.current.clear();
return undefined;
});
if (usage) {
console.log('Token usage:', usage);
@@ -230,6 +237,21 @@
}
}, [model, provider, searchParams]);
// Surface streamed error payloads as toasts, at most once per error id.
// useChat can deliver the same data array more than once (even after it
// is cleared with setData), so alertedErrorIds records which ids have
// already been shown — presumably reset in onFinish; TODO confirm.
useEffect(() => {
if (chatData) {
for (const data of chatData) {
// Narrow untyped stream entries via the 'error' discriminant before casting.
if (data && typeof data === 'object' && 'type' in data && data.type === 'error') {
const error = data as DataStreamError;
if (!alertedErrorIds.current.has(error.id)) {
toast.error('There was an error processing your request: ' + error.message);
alertedErrorIds.current.add(error.id);
}
}
}
}
}, [chatData]);
const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = usePromptEnhancer();
const { parsedMessages, parseMessages } = useMessageParser();

View File

@@ -100,6 +100,8 @@ const ProgressItem = ({ progress }: { progress: ProgressAnnotation }) => {
<div className="i-svg-spinners:90-ring-with-bg"></div>
) : progress.status === 'complete' ? (
<div className="i-ph:check"></div>
) : progress.status === 'error' ? (
<div className="i-ph:warning"></div>
) : null}
</div>
{/* {x.label} */}

View File

@@ -3,11 +3,10 @@ import { createDataStream, generateId } from 'ai';
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS, type FileMap } from '~/lib/.server/llm/constants';
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
import type { IProviderSetting } from '~/types/model';
import { createScopedLogger } from '~/utils/logger';
import { getFilePaths, selectContext } from '~/lib/.server/llm/select-context';
import type { ContextAnnotation, ProgressAnnotation } from '~/types/context';
import type { ContextAnnotation, DataStreamError, ProgressAnnotation } from '~/types/context';
import { WORK_DIR } from '~/utils/constants';
import { createSummary } from '~/lib/.server/llm/create-summary';
import { extractPropertiesFromMessage } from '~/lib/.server/llm/utils';
@@ -58,7 +57,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
parseCookies(cookieHeader || '').providers || '{}',
);
const stream = new SwitchableStream();
let responseSegments = 0;
const cumulativeUsage = {
completionTokens: 0,
@@ -217,15 +216,30 @@
} satisfies ProgressAnnotation);
await new Promise((resolve) => setTimeout(resolve, 0));
// stream.close();
return;
}
if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
throw Error('Cannot continue message: Maximum segments reached');
responseSegments++;
if (responseSegments >= MAX_RESPONSE_SEGMENTS) {
dataStream.writeData({
type: 'error',
id: generateId(),
message: 'Cannot continue message: Maximum segments reached.',
} satisfies DataStreamError);
dataStream.writeData({
type: 'progress',
label: 'summary',
status: 'error',
order: progressCounter++,
message: 'Error: maximum segments reached.',
} satisfies ProgressAnnotation);
await new Promise((resolve) => setTimeout(resolve, 0));
return;
}
const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
const switchesLeft = MAX_RESPONSE_SEGMENTS - responseSegments;
logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);

View File

@@ -12,7 +12,13 @@ export type ContextAnnotation =
/**
 * Progress update written to the data stream while a chat response is
 * generated. `status: 'error'` marks a failed step (e.g. hitting the
 * maximum-response-segments limit); `order` sequences updates client-side.
 */
export type ProgressAnnotation = {
type: 'progress';
label: string;
status: 'in-progress' | 'complete' | 'error';
order: number;
message: string;
};
/**
 * Error payload written to the data stream. `id` lets the client
 * de-duplicate alerts when the same data entry is delivered more than once.
 */
export type DataStreamError = {
type: 'error';
id: string;
message: string;
};