Merge pull request #347 from SujalXplores/fix/enhance-prompt

fix: enhance prompt "Invalid or missing provider" bad request error
This commit is contained in:
Eduard Ruzga 2024-11-20 16:40:15 +02:00 committed by GitHub
commit 9c657b962b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 32 deletions

View File

@ -1,4 +1,5 @@
import { useState } from 'react'; import { useState } from 'react';
import type { ProviderInfo } from '~/types/model';
import { createScopedLogger } from '~/utils/logger'; import { createScopedLogger } from '~/utils/logger';
const logger = createScopedLogger('usePromptEnhancement'); const logger = createScopedLogger('usePromptEnhancement');
@ -13,54 +14,54 @@ export function usePromptEnhancer() {
}; };
const enhancePrompt = async ( const enhancePrompt = async (
input: string, input: string,
setInput: (value: string) => void, setInput: (value: string) => void,
model: string, model: string,
provider: string, provider: ProviderInfo,
apiKeys?: Record<string, string> apiKeys?: Record<string, string>,
) => { ) => {
setEnhancingPrompt(true); setEnhancingPrompt(true);
setPromptEnhanced(false); setPromptEnhanced(false);
const requestBody: any = { const requestBody: any = {
message: input, message: input,
model, model,
provider, provider,
}; };
if (apiKeys) { if (apiKeys) {
requestBody.apiKeys = apiKeys; requestBody.apiKeys = apiKeys;
} }
const response = await fetch('/api/enhancer', { const response = await fetch('/api/enhancer', {
method: 'POST', method: 'POST',
body: JSON.stringify(requestBody), body: JSON.stringify(requestBody),
}); });
const reader = response.body?.getReader(); const reader = response.body?.getReader();
const originalInput = input; const originalInput = input;
if (reader) { if (reader) {
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let _input = ''; let _input = '';
let _error; let _error;
try { try {
setInput(''); setInput('');
while (true) { while (true) {
const { value, done } = await reader.read(); const { value, done } = await reader.read();
if (done) { if (done) {
break; break;
} }
_input += decoder.decode(value); _input += decoder.decode(value);
logger.trace('Set input', _input); logger.trace('Set input', _input);
setInput(_input); setInput(_input);
} }
} catch (error) { } catch (error) {
@ -70,10 +71,10 @@ export function usePromptEnhancer() {
if (_error) { if (_error) {
logger.error(_error); logger.error(_error);
} }
setEnhancingPrompt(false); setEnhancingPrompt(false);
setPromptEnhanced(true); setPromptEnhanced(true);
setTimeout(() => { setTimeout(() => {
setInput(_input); setInput(_input);
}); });

View File

@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { StreamingTextResponse, parseStreamPart } from 'ai'; import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text'; import { streamText } from '~/lib/.server/llm/stream-text';
import { stripIndents } from '~/utils/stripIndent'; import { stripIndents } from '~/utils/stripIndent';
import type { StreamingOptions } from '~/lib/.server/llm/stream-text'; import type { ProviderInfo } from '~/types/model';
const encoder = new TextEncoder(); const encoder = new TextEncoder();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
@ -12,25 +12,27 @@ export async function action(args: ActionFunctionArgs) {
} }
async function enhancerAction({ context, request }: ActionFunctionArgs) { async function enhancerAction({ context, request }: ActionFunctionArgs) {
const { message, model, provider, apiKeys } = await request.json<{ const { message, model, provider, apiKeys } = await request.json<{
message: string; message: string;
model: string; model: string;
provider: string; provider: ProviderInfo;
apiKeys?: Record<string, string>; apiKeys?: Record<string, string>;
}>(); }>();
// Validate 'model' and 'provider' fields const { name: providerName } = provider;
// validate 'model' and 'provider' fields
if (!model || typeof model !== 'string') { if (!model || typeof model !== 'string') {
throw new Response('Invalid or missing model', { throw new Response('Invalid or missing model', {
status: 400, status: 400,
statusText: 'Bad Request' statusText: 'Bad Request',
}); });
} }
if (!provider || typeof provider !== 'string') { if (!providerName || typeof providerName !== 'string') {
throw new Response('Invalid or missing provider', { throw new Response('Invalid or missing provider', {
status: 400, status: 400,
statusText: 'Bad Request' statusText: 'Bad Request',
}); });
} }
@ -39,7 +41,9 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
[ [
{ {
role: 'user', role: 'user',
content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents` content:
`[Model: ${model}]\n\n[Provider: ${providerName}]\n\n` +
stripIndents`
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags. I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
IMPORTANT: Only respond with the improved prompt and nothing else! IMPORTANT: Only respond with the improved prompt and nothing else!
@ -52,23 +56,24 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
], ],
context.cloudflare.env, context.cloudflare.env,
undefined, undefined,
apiKeys apiKeys,
); );
const transformStream = new TransformStream({ const transformStream = new TransformStream({
transform(chunk, controller) { transform(chunk, controller) {
const text = decoder.decode(chunk); const text = decoder.decode(chunk);
const lines = text.split('\n').filter(line => line.trim() !== ''); const lines = text.split('\n').filter((line) => line.trim() !== '');
for (const line of lines) { for (const line of lines) {
try { try {
const parsed = parseStreamPart(line); const parsed = parseStreamPart(line);
if (parsed.type === 'text') { if (parsed.type === 'text') {
controller.enqueue(encoder.encode(parsed.value)); controller.enqueue(encoder.encode(parsed.value));
} }
} catch (e) { } catch (e) {
// Skip invalid JSON lines // skip invalid JSON lines
console.warn('Failed to parse stream part:', line); console.warn('Failed to parse stream part:', line, e);
} }
} }
}, },
@ -83,7 +88,7 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
if (error instanceof Error && error.message?.includes('API key')) { if (error instanceof Error && error.message?.includes('API key')) {
throw new Response('Invalid or missing API key', { throw new Response('Invalid or missing API key', {
status: 401, status: 401,
statusText: 'Unauthorized' statusText: 'Unauthorized',
}); });
} }