mirror of https://github.com/stackblitz-labs/bolt.diy
synced 2025-06-26 18:26:38 +00:00

Enforce max response segments

This commit is contained in:
parent f0aa58c922
commit f01e43d798

@@ -27,6 +27,7 @@ import { logStore } from '~/lib/stores/logs';
 import { streamingState } from '~/lib/stores/streaming';
 import { filesToArtifacts } from '~/utils/fileUtils';
 import { supabaseConnection } from '~/lib/stores/supabase';
+import type { DataStreamError } from '~/types/context';
 
 const toastAnimation = cssTransition({
   enter: 'animated fadeInRight',

@@ -148,6 +149,9 @@ export const ChatImpl = memo(
 
     const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
 
+    // Keep track of the errors we alerted on. useChat gets the same data twice even if they're removed with setData
+    const alertedErrorIds = useRef(new Set());
+
     const {
       messages,
       isLoading,

@@ -191,7 +195,10 @@
       },
       onFinish: (message, response) => {
         const usage = response.usage;
-        setData(undefined);
+        setData(() => {
+          alertedErrorIds.current.clear();
+          return undefined;
+        });
 
         if (usage) {
           console.log('Token usage:', usage);

@@ -230,6 +237,21 @@
       }
     }, [model, provider, searchParams]);
 
+    useEffect(() => {
+      if (chatData) {
+        for (const data of chatData) {
+          if (data && typeof data === 'object' && 'type' in data && data.type === 'error') {
+            const error = data as DataStreamError;
+
+            if (!alertedErrorIds.current.has(error.id)) {
+              toast.error('There was an error processing your request: ' + error.message);
+              alertedErrorIds.current.add(error.id);
+            }
+          }
+        }
+      }
+    }, [chatData]);
+
     const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = usePromptEnhancer();
     const { parsedMessages, parseMessages } = useMessageParser();

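Taken together, the three hunks above implement an alert-once protocol on the client: each streamed error carries a unique id, a Set remembers which ids have already been toasted, and the Set is cleared when a response finishes. A framework-free sketch of the same mechanism (alertOnce, onResponseFinished, and notify are hypothetical names for illustration, not part of this commit):

import type { DataStreamError } from '~/types/context';

// Mirrors the alertedErrorIds ref: remembers which error ids were shown.
const alertedErrorIds = new Set<string>();

// Scan a batch of streamed data items and notify once per unique error id.
// useChat can deliver the same items twice, so the Set suppresses repeats.
function alertOnce(items: unknown[], notify: (message: string) => void) {
  for (const data of items) {
    if (data && typeof data === 'object' && 'type' in data && data.type === 'error') {
      const error = data as DataStreamError;

      if (!alertedErrorIds.has(error.id)) {
        notify('There was an error processing your request: ' + error.message);
        alertedErrorIds.add(error.id);
      }
    }
  }
}

// Mirrors the setData(() => { ... }) updater in onFinish: resetting the
// cache together with the data keeps both in sync for the next response.
function onResponseFinished() {
  alertedErrorIds.clear();
}
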
@@ -100,6 +100,8 @@ const ProgressItem = ({ progress }: { progress: ProgressAnnotation }) => {
         <div className="i-svg-spinners:90-ring-with-bg"></div>
       ) : progress.status === 'complete' ? (
         <div className="i-ph:check"></div>
+      ) : progress.status === 'error' ? (
+        <div className="i-ph:warning"></div>
       ) : null}
     </div>
     {/* {x.label} */}

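The ternary chain above grows by two lines per status. An alternative sketch, not what the commit does: a lookup keyed by the ProgressAnnotation status union, so adding a status to the union forces an entry here at compile time. The icon class names come from the hunk; statusIcon is a hypothetical name.

import type { ProgressAnnotation } from '~/types/context';

// One icon class per status; exhaustiveness is enforced by the Record type.
const statusIcon: Record<ProgressAnnotation['status'], string> = {
  'in-progress': 'i-svg-spinners:90-ring-with-bg',
  complete: 'i-ph:check',
  error: 'i-ph:warning',
};

// Hypothetical usage inside the component:
// <div className={statusIcon[progress.status]}></div>
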
@@ -3,11 +3,10 @@ import { createDataStream, generateId } from 'ai';
 import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS, type FileMap } from '~/lib/.server/llm/constants';
 import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
 import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
-import SwitchableStream from '~/lib/.server/llm/switchable-stream';
 import type { IProviderSetting } from '~/types/model';
 import { createScopedLogger } from '~/utils/logger';
 import { getFilePaths, selectContext } from '~/lib/.server/llm/select-context';
-import type { ContextAnnotation, ProgressAnnotation } from '~/types/context';
+import type { ContextAnnotation, DataStreamError, ProgressAnnotation } from '~/types/context';
 import { WORK_DIR } from '~/utils/constants';
 import { createSummary } from '~/lib/.server/llm/create-summary';
 import { extractPropertiesFromMessage } from '~/lib/.server/llm/utils';

@@ -58,7 +57,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
     parseCookies(cookieHeader || '').providers || '{}',
   );
 
-  const stream = new SwitchableStream();
+  let responseSegments = 0;
 
   const cumulativeUsage = {
     completionTokens: 0,

@@ -217,15 +216,30 @@
           } satisfies ProgressAnnotation);
           await new Promise((resolve) => setTimeout(resolve, 0));
 
-          // stream.close();
           return;
         }
 
-        if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
-          throw Error('Cannot continue message: Maximum segments reached');
+        responseSegments++;
+
+        if (responseSegments >= MAX_RESPONSE_SEGMENTS) {
+          dataStream.writeData({
+            type: 'error',
+            id: generateId(),
+            message: 'Cannot continue message: Maximum segments reached.',
+          } satisfies DataStreamError);
+          dataStream.writeData({
+            type: 'progress',
+            label: 'summary',
+            status: 'error',
+            order: progressCounter++,
+            message: 'Error: maximum segments reached.',
+          } satisfies ProgressAnnotation);
+          await new Promise((resolve) => setTimeout(resolve, 0));
+
+          return;
         }
 
-        const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
+        const switchesLeft = MAX_RESPONSE_SEGMENTS - responseSegments;
 
         logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);

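The behavioral change in this hunk: hitting MAX_RESPONSE_SEGMENTS now writes a structured error into the data stream and returns cleanly, where the old code threw mid-stream. A condensed sketch of that capping logic as a standalone loop (the route actually drives continuations from its streaming callbacks; continueSegment, sendError, and sendProgress are hypothetical stand-ins for the continuation call and dataStream.writeData):

import { MAX_RESPONSE_SEGMENTS } from '~/lib/.server/llm/constants';

// Hypothetical stand-ins for the real streaming plumbing.
declare function continueSegment(): Promise<{ finished: boolean }>;
declare function sendError(message: string): void;
declare function sendProgress(message: string): void;

async function streamWithSegmentCap(): Promise<void> {
  let responseSegments = 0;

  // Each continuation past the token limit consumes one segment.
  while (true) {
    const { finished } = await continueSegment();

    if (finished) {
      return;
    }

    responseSegments++;

    if (responseSegments >= MAX_RESPONSE_SEGMENTS) {
      // Report through the stream instead of throwing, so the client can
      // show a toast and an error progress row rather than a dead connection.
      sendError('Cannot continue message: Maximum segments reached.');
      sendProgress('Error: maximum segments reached.');
      return;
    }
  }
}
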
@@ -12,7 +12,13 @@ export type ContextAnnotation =
 export type ProgressAnnotation = {
   type: 'progress';
   label: string;
-  status: 'in-progress' | 'complete';
+  status: 'in-progress' | 'complete' | 'error';
   order: number;
   message: string;
 };
+
+export type DataStreamError = {
+  type: 'error';
+  id: string;
+  message: string;
+};

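DataStreamError's fixed type: 'error' discriminant is what makes the inline checks in the Chat hunk narrowable. A sketch of a reusable guard for those checks (isDataStreamError is a hypothetical helper, not part of this commit):

import type { DataStreamError } from '~/types/context';

// Narrow an unknown streamed data item to DataStreamError by checking the
// discriminant and the shape of the two payload fields.
function isDataStreamError(data: unknown): data is DataStreamError {
  return (
    typeof data === 'object' &&
    data !== null &&
    'type' in data &&
    (data as { type?: unknown }).type === 'error' &&
    typeof (data as { id?: unknown }).id === 'string' &&
    typeof (data as { message?: unknown }).message === 'string'
  );
}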