Merge pull request #347 from SujalXplores/fix/enhance-prompt

fix: enhance prompt "Invalid or missing provider" bad request error
This commit is contained in:
Eduard Ruzga 2024-11-20 16:40:15 +02:00 committed by GitHub
commit 9c657b962b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 32 deletions

View File

@ -1,4 +1,5 @@
import { useState } from 'react'; import { useState } from 'react';
import type { ProviderInfo } from '~/types/model';
import { createScopedLogger } from '~/utils/logger'; import { createScopedLogger } from '~/utils/logger';
const logger = createScopedLogger('usePromptEnhancement'); const logger = createScopedLogger('usePromptEnhancement');
@ -16,8 +17,8 @@ export function usePromptEnhancer() {
input: string, input: string,
setInput: (value: string) => void, setInput: (value: string) => void,
model: string, model: string,
provider: string, provider: ProviderInfo,
apiKeys?: Record<string, string> apiKeys?: Record<string, string>,
) => { ) => {
setEnhancingPrompt(true); setEnhancingPrompt(true);
setPromptEnhanced(false); setPromptEnhanced(false);

View File

@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { StreamingTextResponse, parseStreamPart } from 'ai'; import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text'; import { streamText } from '~/lib/.server/llm/stream-text';
import { stripIndents } from '~/utils/stripIndent'; import { stripIndents } from '~/utils/stripIndent';
import type { StreamingOptions } from '~/lib/.server/llm/stream-text'; import type { ProviderInfo } from '~/types/model';
const encoder = new TextEncoder(); const encoder = new TextEncoder();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
@ -15,22 +15,24 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
const { message, model, provider, apiKeys } = await request.json<{ const { message, model, provider, apiKeys } = await request.json<{
message: string; message: string;
model: string; model: string;
provider: string; provider: ProviderInfo;
apiKeys?: Record<string, string>; apiKeys?: Record<string, string>;
}>(); }>();
// Validate 'model' and 'provider' fields const { name: providerName } = provider;
// validate 'model' and 'provider' fields
if (!model || typeof model !== 'string') { if (!model || typeof model !== 'string') {
throw new Response('Invalid or missing model', { throw new Response('Invalid or missing model', {
status: 400, status: 400,
statusText: 'Bad Request' statusText: 'Bad Request',
}); });
} }
if (!provider || typeof provider !== 'string') { if (!providerName || typeof providerName !== 'string') {
throw new Response('Invalid or missing provider', { throw new Response('Invalid or missing provider', {
status: 400, status: 400,
statusText: 'Bad Request' statusText: 'Bad Request',
}); });
} }
@ -39,7 +41,9 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
[ [
{ {
role: 'user', role: 'user',
content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents` content:
`[Model: ${model}]\n\n[Provider: ${providerName}]\n\n` +
stripIndents`
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags. I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
IMPORTANT: Only respond with the improved prompt and nothing else! IMPORTANT: Only respond with the improved prompt and nothing else!
@ -52,23 +56,24 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
], ],
context.cloudflare.env, context.cloudflare.env,
undefined, undefined,
apiKeys apiKeys,
); );
const transformStream = new TransformStream({ const transformStream = new TransformStream({
transform(chunk, controller) { transform(chunk, controller) {
const text = decoder.decode(chunk); const text = decoder.decode(chunk);
const lines = text.split('\n').filter(line => line.trim() !== ''); const lines = text.split('\n').filter((line) => line.trim() !== '');
for (const line of lines) { for (const line of lines) {
try { try {
const parsed = parseStreamPart(line); const parsed = parseStreamPart(line);
if (parsed.type === 'text') { if (parsed.type === 'text') {
controller.enqueue(encoder.encode(parsed.value)); controller.enqueue(encoder.encode(parsed.value));
} }
} catch (e) { } catch (e) {
// Skip invalid JSON lines // skip invalid JSON lines
console.warn('Failed to parse stream part:', line); console.warn('Failed to parse stream part:', line, e);
} }
} }
}, },
@ -83,7 +88,7 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
if (error instanceof Error && error.message?.includes('API key')) { if (error instanceof Error && error.message?.includes('API key')) {
throw new Response('Invalid or missing API key', { throw new Response('Invalid or missing API key', {
status: 401, status: 401,
statusText: 'Unauthorized' statusText: 'Unauthorized',
}); });
} }