Merge pull request #254 from ali00209/new_bolt5

fix: bug #245
Cole Medin 2024-11-11 18:39:48 -06:00 committed by GitHub
commit 0b75051ba5
3 changed files with 108 additions and 37 deletions

app/components/chat/Chat.client.tsx

@@ -74,8 +74,14 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
   const textareaRef = useRef<HTMLTextAreaElement>(null);
   const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
-  const [model, setModel] = useState(DEFAULT_MODEL);
-  const [provider, setProvider] = useState(DEFAULT_PROVIDER);
+  const [model, setModel] = useState(() => {
+    const savedModel = Cookies.get('selectedModel');
+    return savedModel || DEFAULT_MODEL;
+  });
+  const [provider, setProvider] = useState(() => {
+    const savedProvider = Cookies.get('selectedProvider');
+    return savedProvider || DEFAULT_PROVIDER;
+  });

   const { showChat } = useStore(chatStore);
@@ -216,6 +222,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
     }
   }, []);

+  const handleModelChange = (newModel: string) => {
+    setModel(newModel);
+    Cookies.set('selectedModel', newModel, { expires: 30 });
+  };
+
+  const handleProviderChange = (newProvider: string) => {
+    setProvider(newProvider);
+    Cookies.set('selectedProvider', newProvider, { expires: 30 });
+  };
+
   return (
     <BaseChat
       ref={animationScope}
@@ -228,9 +244,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
       promptEnhanced={promptEnhanced}
      sendMessage={sendMessage}
      model={model}
-      setModel={setModel}
+      setModel={handleModelChange}
      provider={provider}
-      setProvider={setProvider}
+      setProvider={handleProviderChange}
      messageRef={messageRef}
      scrollRef={scrollRef}
      handleInputChange={handleInputChange}
@@ -246,10 +262,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
         };
       })}
       enhancePrompt={() => {
-        enhancePrompt(input, (input) => {
-          setInput(input);
-          scrollTextArea();
-        });
+        enhancePrompt(
+          input,
+          (input) => {
+            setInput(input);
+            scrollTextArea();
+          },
+          model,
+          provider,
+          apiKeys
+        );
       }}
     />
   );
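
In Chat.client.tsx, the selected model and provider are now seeded from cookies and written back whenever the user changes them, so the selection survives page reloads. A minimal standalone sketch of that pattern, assuming the js-cookie API (Cookies.get / Cookies.set) the component already uses; the hook name and default value below are illustrative, not part of the PR:

// Sketch only: the cookie-persistence pattern used in the diff above.
import Cookies from 'js-cookie';
import { useState } from 'react';

// Hypothetical default; the real component uses its own DEFAULT_MODEL constant.
const DEFAULT_MODEL = 'claude-3-5-sonnet-latest';

export function usePersistedModel() {
  // Lazy initializer: read the cookie once on first render, fall back to the default.
  const [model, setModel] = useState<string>(() => Cookies.get('selectedModel') || DEFAULT_MODEL);

  const handleModelChange = (newModel: string) => {
    setModel(newModel);
    // Persist for 30 days, matching the handlers in the diff.
    Cookies.set('selectedModel', newModel, { expires: 30 });
  };

  return [model, handleModelChange] as const;
}

The lazy useState initializer keeps the cookie read out of every render, and routing all updates through the handler guarantees the cookie and the React state never drift apart.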

app/lib/hooks/usePromptEnhancer.ts

@@ -12,15 +12,29 @@ export function usePromptEnhancer() {
     setPromptEnhanced(false);
   };

-  const enhancePrompt = async (input: string, setInput: (value: string) => void) => {
+  const enhancePrompt = async (
+    input: string,
+    setInput: (value: string) => void,
+    model: string,
+    provider: string,
+    apiKeys?: Record<string, string>
+  ) => {
     setEnhancingPrompt(true);
     setPromptEnhanced(false);
+
+    const requestBody: any = {
+      message: input,
+      model,
+      provider,
+    };
+
+    if (apiKeys) {
+      requestBody.apiKeys = apiKeys;
+    }
+
     const response = await fetch('/api/enhancer', {
       method: 'POST',
-      body: JSON.stringify({
-        message: input,
-      }),
+      body: JSON.stringify(requestBody),
     });

     const reader = response.body?.getReader();
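
In usePromptEnhancer, enhancePrompt now accepts the active model, provider, and an optional apiKeys map, and only attaches apiKeys to the POST body when one is supplied. A self-contained sketch of the same request/response contract outside the hook; the function name and example values are illustrative, and the streaming read only mirrors what plausibly follows response.body?.getReader(), which this diff does not show:

// Illustrative standalone version of the contract the updated hook implements.
type ApiKeys = Record<string, string>;

async function enhanceViaApi(
  message: string,
  model: string,
  provider: string,
  onUpdate: (text: string) => void,
  apiKeys?: ApiKeys,
): Promise<void> {
  // apiKeys is optional and only sent when present, matching the diff above.
  const requestBody: { message: string; model: string; provider: string; apiKeys?: ApiKeys } = {
    message,
    model,
    provider,
  };
  if (apiKeys) {
    requestBody.apiKeys = apiKeys;
  }

  const response = await fetch('/api/enhancer', {
    method: 'POST',
    body: JSON.stringify(requestBody),
  });

  // The route streams plain text, so the reply can be surfaced incrementally.
  const reader = response.body?.getReader();
  if (!reader) {
    return;
  }

  const decoder = new TextDecoder();
  let enhanced = '';

  while (true) {
    const { value, done } = await reader.read();
    if (done) {
      break;
    }
    enhanced += decoder.decode(value, { stream: true });
    onUpdate(enhanced); // e.g. setInput(enhanced) to update the textarea as text arrives
  }
}

// Example call (model/provider/key values are hypothetical):
// await enhanceViaApi('build a todo app', 'gpt-4o', 'OpenAI', (t) => console.log(t), { OpenAI: 'sk-...' });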

app/routes/api.enhancer.ts

@@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
 import { StreamingTextResponse, parseStreamPart } from 'ai';
 import { streamText } from '~/lib/.server/llm/stream-text';
 import { stripIndents } from '~/utils/stripIndent';
+import type { StreamingOptions } from '~/lib/.server/llm/stream-text';

 const encoder = new TextEncoder();
 const decoder = new TextDecoder();
@@ -11,14 +12,34 @@ export async function action(args: ActionFunctionArgs) {
 }

 async function enhancerAction({ context, request }: ActionFunctionArgs) {
-  const { message } = await request.json<{ message: string }>();
+  const { message, model, provider, apiKeys } = await request.json<{
+    message: string;
+    model: string;
+    provider: string;
+    apiKeys?: Record<string, string>;
+  }>();
+
+  // Validate 'model' and 'provider' fields
+  if (!model || typeof model !== 'string') {
+    throw new Response('Invalid or missing model', {
+      status: 400,
+      statusText: 'Bad Request'
+    });
+  }
+
+  if (!provider || typeof provider !== 'string') {
+    throw new Response('Invalid or missing provider', {
+      status: 400,
+      statusText: 'Bad Request'
+    });
+  }

   try {
     const result = await streamText(
       [
         {
           role: 'user',
-          content: stripIndents`
+          content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents`
           I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.

           IMPORTANT: Only respond with the improved prompt and nothing else!
@@ -30,28 +51,42 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
         },
       ],
       context.cloudflare.env,
+      undefined,
+      apiKeys
     );

     const transformStream = new TransformStream({
       transform(chunk, controller) {
-        const processedChunk = decoder
-          .decode(chunk)
-          .split('\n')
-          .filter((line) => line !== '')
-          .map(parseStreamPart)
-          .map((part) => part.value)
-          .join('');
-
-        controller.enqueue(encoder.encode(processedChunk));
+        const text = decoder.decode(chunk);
+        const lines = text.split('\n').filter(line => line.trim() !== '');
+
+        for (const line of lines) {
+          try {
+            const parsed = parseStreamPart(line);
+            if (parsed.type === 'text') {
+              controller.enqueue(encoder.encode(parsed.value));
+            }
+          } catch (e) {
+            // Skip invalid JSON lines
+            console.warn('Failed to parse stream part:', line);
+          }
+        }
       },
     });

     const transformedStream = result.toAIStream().pipeThrough(transformStream);

     return new StreamingTextResponse(transformedStream);
-  } catch (error) {
+  } catch (error: unknown) {
     console.log(error);
+
+    if (error instanceof Error && error.message?.includes('API key')) {
+      throw new Response('Invalid or missing API key', {
+        status: 401,
+        statusText: 'Unauthorized'
+      });
+    }
+
     throw new Response(null, {
       status: 500,
       statusText: 'Internal Server Error',
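
On the server, api.enhancer.ts now pulls model, provider, and optional apiKeys out of the JSON body, rejects requests that omit model or provider with 400, prefixes the prompt with [Model: ...] and [Provider: ...] markers, forwards apiKeys to streamText, passes only 'text' stream parts back to the client, and maps API-key failures to 401 instead of a blanket 500. A hedged client-side sketch of that contract; the helper name and error messages are illustrative, only the status codes come from the diff:

// Sketch: exercising the updated endpoint and its failure modes (illustrative values).
async function callEnhancer(
  message: string,
  model: string,
  provider: string,
  apiKeys?: Record<string, string>,
): Promise<string> {
  const res = await fetch('/api/enhancer', {
    method: 'POST',
    body: JSON.stringify({ message, model, provider, ...(apiKeys ? { apiKeys } : {}) }),
  });

  if (res.status === 400) {
    // Missing or non-string model/provider is rejected before any LLM call.
    throw new Error('Bad request: model and provider are required');
  }
  if (res.status === 401) {
    // API-key problems are now distinguished from generic server errors.
    throw new Error('Unauthorized: invalid or missing API key');
  }
  if (!res.ok) {
    throw new Error(`Enhancer failed with status ${res.status}`);
  }

  // The transform stream forwards only 'text' parts, so the body resolves to the improved prompt.
  return res.text();
}

Because non-text stream parts and unparsable lines are dropped in the transform, the client receives plain text it can stream straight into the prompt box without any further parsing.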