updated implementation

This commit is contained in:
Anirban Kar 2024-12-16 19:47:18 +05:30
parent 070e911be1
commit 3b8d251a55
6 changed files with 39 additions and 24 deletions

View File

@ -1 +1 @@
{ "commit": "6ba93974a02a98c83badf2f0002ff4812b8f75a9" } { "commit": "070e911be17e1e1f3994220c3ed89b0060c67bd2" }

View File

@ -1,15 +1,22 @@
import { memo } from 'react'; import { memo } from 'react';
import { Markdown } from './Markdown'; import { Markdown } from './Markdown';
import { USAGE_REGEX } from '~/utils/constants'; import type { JSONValue } from 'ai';
interface AssistantMessageProps { interface AssistantMessageProps {
content: string; content: string;
annotations?: JSONValue[];
} }
export const AssistantMessage = memo(({ content }: AssistantMessageProps) => { export const AssistantMessage = memo(({ content, annotations }: AssistantMessageProps) => {
const match = content.match(USAGE_REGEX); const filteredAnnotations = (annotations?.filter(
const usage = match ? JSON.parse(match[1]) : null; (annotation: JSONValue) => annotation && typeof annotation === 'object' && Object.keys(annotation).includes('type'),
const cleanContent = content.replace(USAGE_REGEX, '').trim(); ) || []) as { type: string; value: any }[];
const usage: {
completionTokens: number;
promptTokens: number;
totalTokens: number;
} = filteredAnnotations.find((annotation) => annotation.type === 'usage')?.value;
return ( return (
<div className="overflow-hidden w-full"> <div className="overflow-hidden w-full">
@ -18,7 +25,7 @@ export const AssistantMessage = memo(({ content }: AssistantMessageProps) => {
Tokens: {usage.totalTokens} (prompt: {usage.promptTokens}, completion: {usage.completionTokens}) Tokens: {usage.totalTokens} (prompt: {usage.promptTokens}, completion: {usage.completionTokens})
</div> </div>
)} )}
<Markdown html>{cleanContent}</Markdown> <Markdown html>{content}</Markdown>
</div> </div>
); );
}); });

View File

@ -65,7 +65,11 @@ export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props:
</div> </div>
)} )}
<div className="grid grid-col-1 w-full"> <div className="grid grid-col-1 w-full">
{isUserMessage ? <UserMessage content={content} /> : <AssistantMessage content={content} />} {isUserMessage ? (
<UserMessage content={content} />
) : (
<AssistantMessage content={content} annotations={message.annotations} />
)}
</div> </div>
{!isUserMessage && ( {!isUserMessage && (
<div className="flex gap-2 flex-col lg:flex-row"> <div className="flex gap-2 flex-col lg:flex-row">

View File

@ -1,5 +1,5 @@
export default class SwitchableStream extends TransformStream { export default class SwitchableStream extends TransformStream {
_controller: TransformStreamDefaultController | null = null; private _controller: TransformStreamDefaultController | null = null;
private _currentReader: ReadableStreamDefaultReader | null = null; private _currentReader: ReadableStreamDefaultReader | null = null;
private _switches = 0; private _switches = 0;

View File

@ -1,4 +1,5 @@
import { type ActionFunctionArgs } from '@remix-run/cloudflare'; import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { createDataStream } from 'ai';
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants'; import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts'; import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text'; import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
@ -53,26 +54,30 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
onFinish: async ({ text: content, finishReason, usage }) => { onFinish: async ({ text: content, finishReason, usage }) => {
console.log('usage', usage); console.log('usage', usage);
if (usage && stream._controller) { if (usage) {
cumulativeUsage.completionTokens += usage.completionTokens || 0; cumulativeUsage.completionTokens += usage.completionTokens || 0;
cumulativeUsage.promptTokens += usage.promptTokens || 0; cumulativeUsage.promptTokens += usage.promptTokens || 0;
cumulativeUsage.totalTokens += usage.totalTokens || 0; cumulativeUsage.totalTokens += usage.totalTokens || 0;
// Send usage info in message metadata for assistant messages
const usageMetadata = `0:"[Usage: ${JSON.stringify({
completionTokens: cumulativeUsage.completionTokens,
promptTokens: cumulativeUsage.promptTokens,
totalTokens: cumulativeUsage.totalTokens,
})}\n]"`;
console.log(usageMetadata);
const encodedData = new TextEncoder().encode(usageMetadata);
stream._controller.enqueue(encodedData);
} }
if (finishReason !== 'length') { if (finishReason !== 'length') {
return stream.close(); return stream
.switchSource(
createDataStream({
async execute(dataStream) {
dataStream.writeMessageAnnotation({
type: 'usage',
value: {
completionTokens: cumulativeUsage.completionTokens,
promptTokens: cumulativeUsage.promptTokens,
totalTokens: cumulativeUsage.totalTokens,
},
});
},
onError: (error: any) => `Custom error: ${error.message}`,
}),
)
.then(() => stream.close());
} }
if (stream.switches >= MAX_RESPONSE_SEGMENTS) { if (stream.switches >= MAX_RESPONSE_SEGMENTS) {

View File

@ -9,7 +9,6 @@ export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications'; export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/; export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/;
export const PROVIDER_REGEX = /\[Provider: (.*?)\]\n\n/; export const PROVIDER_REGEX = /\[Provider: (.*?)\]\n\n/;
export const USAGE_REGEX = /\[Usage: ({.*?})\]/; // Keep this regex for assistant messages
export const DEFAULT_MODEL = 'claude-3-5-sonnet-latest'; export const DEFAULT_MODEL = 'claude-3-5-sonnet-latest';
export const PROMPT_COOKIE_KEY = 'cachedPrompt'; export const PROMPT_COOKIE_KEY = 'cachedPrompt';