diff --git a/app/lib/hooks/usePromptEnhancer.ts b/app/lib/hooks/usePromptEnhancer.ts index ee449992..6275ef37 100644 --- a/app/lib/hooks/usePromptEnhancer.ts +++ b/app/lib/hooks/usePromptEnhancer.ts @@ -1,4 +1,5 @@ import { useState } from 'react'; +import type { ProviderInfo } from '~/types/model'; import { createScopedLogger } from '~/utils/logger'; const logger = createScopedLogger('usePromptEnhancement'); @@ -13,54 +14,54 @@ export function usePromptEnhancer() { }; const enhancePrompt = async ( - input: string, + input: string, setInput: (value: string) => void, model: string, - provider: string, - apiKeys?: Record + provider: ProviderInfo, + apiKeys?: Record, ) => { setEnhancingPrompt(true); setPromptEnhanced(false); - + const requestBody: any = { message: input, model, provider, }; - + if (apiKeys) { requestBody.apiKeys = apiKeys; } - + const response = await fetch('/api/enhancer', { method: 'POST', body: JSON.stringify(requestBody), }); - + const reader = response.body?.getReader(); - + const originalInput = input; - + if (reader) { const decoder = new TextDecoder(); - + let _input = ''; let _error; - + try { setInput(''); - + while (true) { const { value, done } = await reader.read(); - + if (done) { break; } - + _input += decoder.decode(value); - + logger.trace('Set input', _input); - + setInput(_input); } } catch (error) { @@ -70,10 +71,10 @@ export function usePromptEnhancer() { if (_error) { logger.error(_error); } - + setEnhancingPrompt(false); setPromptEnhanced(true); - + setTimeout(() => { setInput(_input); }); diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts index 7040b890..77e6f2fd 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import { stripIndents } from '~/utils/stripIndent'; -import type { StreamingOptions } from '~/lib/.server/llm/stream-text'; +import type { ProviderInfo } from '~/types/model'; const encoder = new TextEncoder(); const decoder = new TextDecoder(); @@ -12,25 +12,27 @@ export async function action(args: ActionFunctionArgs) { } async function enhancerAction({ context, request }: ActionFunctionArgs) { - const { message, model, provider, apiKeys } = await request.json<{ + const { message, model, provider, apiKeys } = await request.json<{ message: string; model: string; - provider: string; + provider: ProviderInfo; apiKeys?: Record; }>(); - // Validate 'model' and 'provider' fields + const { name: providerName } = provider; + + // validate 'model' and 'provider' fields if (!model || typeof model !== 'string') { throw new Response('Invalid or missing model', { status: 400, - statusText: 'Bad Request' + statusText: 'Bad Request', }); } - if (!provider || typeof provider !== 'string') { + if (!providerName || typeof providerName !== 'string') { throw new Response('Invalid or missing provider', { status: 400, - statusText: 'Bad Request' + statusText: 'Bad Request', }); } @@ -39,7 +41,9 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { [ { role: 'user', - content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents` + content: + `[Model: ${model}]\n\n[Provider: ${providerName}]\n\n` + + stripIndents` I want you to improve the user prompt that is wrapped in \`\` tags. IMPORTANT: Only respond with the improved prompt and nothing else! @@ -52,23 +56,24 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { ], context.cloudflare.env, undefined, - apiKeys + apiKeys, ); const transformStream = new TransformStream({ transform(chunk, controller) { const text = decoder.decode(chunk); - const lines = text.split('\n').filter(line => line.trim() !== ''); - + const lines = text.split('\n').filter((line) => line.trim() !== ''); + for (const line of lines) { try { const parsed = parseStreamPart(line); + if (parsed.type === 'text') { controller.enqueue(encoder.encode(parsed.value)); } } catch (e) { - // Skip invalid JSON lines - console.warn('Failed to parse stream part:', line); + // skip invalid JSON lines + console.warn('Failed to parse stream part:', line, e); } } }, @@ -83,7 +88,7 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { if (error instanceof Error && error.message?.includes('API key')) { throw new Response('Invalid or missing API key', { status: 401, - statusText: 'Unauthorized' + statusText: 'Unauthorized', }); }