Merge pull request #254 from ali00209/new_bolt5

fix: bug  #245
This commit is contained in:
Cole Medin 2024-11-11 18:39:48 -06:00 committed by GitHub
commit 0b75051ba5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 108 additions and 37 deletions

View File

@ -74,8 +74,14 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
const textareaRef = useRef<HTMLTextAreaElement>(null);
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [model, setModel] = useState(() => {
const savedModel = Cookies.get('selectedModel');
return savedModel || DEFAULT_MODEL;
});
const [provider, setProvider] = useState(() => {
const savedProvider = Cookies.get('selectedProvider');
return savedProvider || DEFAULT_PROVIDER;
});
const { showChat } = useStore(chatStore);
@ -216,6 +222,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
}
}, []);
const handleModelChange = (newModel: string) => {
setModel(newModel);
Cookies.set('selectedModel', newModel, { expires: 30 });
};
const handleProviderChange = (newProvider: string) => {
setProvider(newProvider);
Cookies.set('selectedProvider', newProvider, { expires: 30 });
};
return (
<BaseChat
ref={animationScope}
@ -228,9 +244,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
promptEnhanced={promptEnhanced}
sendMessage={sendMessage}
model={model}
setModel={setModel}
setModel={handleModelChange}
provider={provider}
setProvider={setProvider}
setProvider={handleProviderChange}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange}
@ -246,10 +262,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
};
})}
enhancePrompt={() => {
enhancePrompt(input, (input) => {
setInput(input);
scrollTextArea();
});
enhancePrompt(
input,
(input) => {
setInput(input);
scrollTextArea();
},
model,
provider,
apiKeys
);
}}
/>
);

View File

@ -12,41 +12,55 @@ export function usePromptEnhancer() {
setPromptEnhanced(false);
};
const enhancePrompt = async (input: string, setInput: (value: string) => void) => {
const enhancePrompt = async (
input: string,
setInput: (value: string) => void,
model: string,
provider: string,
apiKeys?: Record<string, string>
) => {
setEnhancingPrompt(true);
setPromptEnhanced(false);
const requestBody: any = {
message: input,
model,
provider,
};
if (apiKeys) {
requestBody.apiKeys = apiKeys;
}
const response = await fetch('/api/enhancer', {
method: 'POST',
body: JSON.stringify({
message: input,
}),
body: JSON.stringify(requestBody),
});
const reader = response.body?.getReader();
const originalInput = input;
if (reader) {
const decoder = new TextDecoder();
let _input = '';
let _error;
try {
setInput('');
while (true) {
const { value, done } = await reader.read();
if (done) {
break;
}
_input += decoder.decode(value);
logger.trace('Set input', _input);
setInput(_input);
}
} catch (error) {
@ -56,10 +70,10 @@ export function usePromptEnhancer() {
if (_error) {
logger.error(_error);
}
setEnhancingPrompt(false);
setPromptEnhanced(true);
setTimeout(() => {
setInput(_input);
});

View File

@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text';
import { stripIndents } from '~/utils/stripIndent';
import type { StreamingOptions } from '~/lib/.server/llm/stream-text';
const encoder = new TextEncoder();
const decoder = new TextDecoder();
@ -11,14 +12,34 @@ export async function action(args: ActionFunctionArgs) {
}
async function enhancerAction({ context, request }: ActionFunctionArgs) {
const { message } = await request.json<{ message: string }>();
const { message, model, provider, apiKeys } = await request.json<{
message: string;
model: string;
provider: string;
apiKeys?: Record<string, string>;
}>();
// Validate 'model' and 'provider' fields
if (!model || typeof model !== 'string') {
throw new Response('Invalid or missing model', {
status: 400,
statusText: 'Bad Request'
});
}
if (!provider || typeof provider !== 'string') {
throw new Response('Invalid or missing provider', {
status: 400,
statusText: 'Bad Request'
});
}
try {
const result = await streamText(
[
{
role: 'user',
content: stripIndents`
content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents`
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
IMPORTANT: Only respond with the improved prompt and nothing else!
@ -30,28 +51,42 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
},
],
context.cloudflare.env,
undefined,
apiKeys
);
const transformStream = new TransformStream({
transform(chunk, controller) {
const processedChunk = decoder
.decode(chunk)
.split('\n')
.filter((line) => line !== '')
.map(parseStreamPart)
.map((part) => part.value)
.join('');
controller.enqueue(encoder.encode(processedChunk));
const text = decoder.decode(chunk);
const lines = text.split('\n').filter(line => line.trim() !== '');
for (const line of lines) {
try {
const parsed = parseStreamPart(line);
if (parsed.type === 'text') {
controller.enqueue(encoder.encode(parsed.value));
}
} catch (e) {
// Skip invalid JSON lines
console.warn('Failed to parse stream part:', line);
}
}
},
});
const transformedStream = result.toAIStream().pipeThrough(transformStream);
return new StreamingTextResponse(transformedStream);
} catch (error) {
} catch (error: unknown) {
console.log(error);
if (error instanceof Error && error.message?.includes('API key')) {
throw new Response('Invalid or missing API key', {
status: 401,
statusText: 'Unauthorized'
});
}
throw new Response(null, {
status: 500,
statusText: 'Internal Server Error',