From 9c848802924f4820d063504663104b952f4559fb Mon Sep 17 00:00:00 2001 From: ali00209 Date: Tue, 12 Nov 2024 05:10:54 +0500 Subject: [PATCH] fix: bug #245 --- app/components/chat/Chat.client.tsx | 38 +++++++++++++++---- app/lib/hooks/usePromptEnhancer.ts | 48 ++++++++++++++--------- app/routes/api.enhancer.ts | 59 +++++++++++++++++++++++------ 3 files changed, 108 insertions(+), 37 deletions(-) diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 102c4c2..a8f94f0 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -74,8 +74,14 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp const textareaRef = useRef(null); const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); - const [model, setModel] = useState(DEFAULT_MODEL); - const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [model, setModel] = useState(() => { + const savedModel = Cookies.get('selectedModel'); + return savedModel || DEFAULT_MODEL; + }); + const [provider, setProvider] = useState(() => { + const savedProvider = Cookies.get('selectedProvider'); + return savedProvider || DEFAULT_PROVIDER; + }); const { showChat } = useStore(chatStore); @@ -216,6 +222,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp } }, []); + const handleModelChange = (newModel: string) => { + setModel(newModel); + Cookies.set('selectedModel', newModel, { expires: 30 }); + }; + + const handleProviderChange = (newProvider: string) => { + setProvider(newProvider); + Cookies.set('selectedProvider', newProvider, { expires: 30 }); + }; + return ( { - enhancePrompt(input, (input) => { - setInput(input); - scrollTextArea(); - }); + enhancePrompt( + input, + (input) => { + setInput(input); + scrollTextArea(); + }, + model, + provider, + apiKeys + ); }} /> ); diff --git a/app/lib/hooks/usePromptEnhancer.ts b/app/lib/hooks/usePromptEnhancer.ts index f376cc0..ee44999 100644 --- a/app/lib/hooks/usePromptEnhancer.ts +++ b/app/lib/hooks/usePromptEnhancer.ts @@ -12,41 +12,55 @@ export function usePromptEnhancer() { setPromptEnhanced(false); }; - const enhancePrompt = async (input: string, setInput: (value: string) => void) => { + const enhancePrompt = async ( + input: string, + setInput: (value: string) => void, + model: string, + provider: string, + apiKeys?: Record + ) => { setEnhancingPrompt(true); setPromptEnhanced(false); - + + const requestBody: any = { + message: input, + model, + provider, + }; + + if (apiKeys) { + requestBody.apiKeys = apiKeys; + } + const response = await fetch('/api/enhancer', { method: 'POST', - body: JSON.stringify({ - message: input, - }), + body: JSON.stringify(requestBody), }); - + const reader = response.body?.getReader(); - + const originalInput = input; - + if (reader) { const decoder = new TextDecoder(); - + let _input = ''; let _error; - + try { setInput(''); - + while (true) { const { value, done } = await reader.read(); - + if (done) { break; } - + _input += decoder.decode(value); - + logger.trace('Set input', _input); - + setInput(_input); } } catch (error) { @@ -56,10 +70,10 @@ export function usePromptEnhancer() { if (_error) { logger.error(_error); } - + setEnhancingPrompt(false); setPromptEnhanced(true); - + setTimeout(() => { setInput(_input); }); diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts index 5c8175c..7040b89 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import { stripIndents } from '~/utils/stripIndent'; +import type { StreamingOptions } from '~/lib/.server/llm/stream-text'; const encoder = new TextEncoder(); const decoder = new TextDecoder(); @@ -11,14 +12,34 @@ export async function action(args: ActionFunctionArgs) { } async function enhancerAction({ context, request }: ActionFunctionArgs) { - const { message } = await request.json<{ message: string }>(); + const { message, model, provider, apiKeys } = await request.json<{ + message: string; + model: string; + provider: string; + apiKeys?: Record; + }>(); + + // Validate 'model' and 'provider' fields + if (!model || typeof model !== 'string') { + throw new Response('Invalid or missing model', { + status: 400, + statusText: 'Bad Request' + }); + } + + if (!provider || typeof provider !== 'string') { + throw new Response('Invalid or missing provider', { + status: 400, + statusText: 'Bad Request' + }); + } try { const result = await streamText( [ { role: 'user', - content: stripIndents` + content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents` I want you to improve the user prompt that is wrapped in \`\` tags. IMPORTANT: Only respond with the improved prompt and nothing else! @@ -30,28 +51,42 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { }, ], context.cloudflare.env, + undefined, + apiKeys ); const transformStream = new TransformStream({ transform(chunk, controller) { - const processedChunk = decoder - .decode(chunk) - .split('\n') - .filter((line) => line !== '') - .map(parseStreamPart) - .map((part) => part.value) - .join(''); - - controller.enqueue(encoder.encode(processedChunk)); + const text = decoder.decode(chunk); + const lines = text.split('\n').filter(line => line.trim() !== ''); + + for (const line of lines) { + try { + const parsed = parseStreamPart(line); + if (parsed.type === 'text') { + controller.enqueue(encoder.encode(parsed.value)); + } + } catch (e) { + // Skip invalid JSON lines + console.warn('Failed to parse stream part:', line); + } + } }, }); const transformedStream = result.toAIStream().pipeThrough(transformStream); return new StreamingTextResponse(transformedStream); - } catch (error) { + } catch (error: unknown) { console.log(error); + if (error instanceof Error && error.message?.includes('API key')) { + throw new Response('Invalid or missing API key', { + status: 401, + statusText: 'Unauthorized' + }); + } + throw new Response(null, { status: 500, statusText: 'Internal Server Error',