Merge pull request #347 from SujalXplores/fix/enhance-prompt

fix: enhance prompt "Invalid or missing provider" bad request error
This commit is contained in:
Eduard Ruzga 2024-11-20 16:40:15 +02:00 committed by GitHub
commit 9c657b962b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 32 deletions

View File

@ -1,4 +1,5 @@
import { useState } from 'react';
import type { ProviderInfo } from '~/types/model';
import { createScopedLogger } from '~/utils/logger';
const logger = createScopedLogger('usePromptEnhancement');
@ -13,54 +14,54 @@ export function usePromptEnhancer() {
};
const enhancePrompt = async (
input: string,
input: string,
setInput: (value: string) => void,
model: string,
provider: string,
apiKeys?: Record<string, string>
provider: ProviderInfo,
apiKeys?: Record<string, string>,
) => {
setEnhancingPrompt(true);
setPromptEnhanced(false);
const requestBody: any = {
message: input,
model,
provider,
};
if (apiKeys) {
requestBody.apiKeys = apiKeys;
}
const response = await fetch('/api/enhancer', {
method: 'POST',
body: JSON.stringify(requestBody),
});
const reader = response.body?.getReader();
const originalInput = input;
if (reader) {
const decoder = new TextDecoder();
let _input = '';
let _error;
try {
setInput('');
while (true) {
const { value, done } = await reader.read();
if (done) {
break;
}
_input += decoder.decode(value);
logger.trace('Set input', _input);
setInput(_input);
}
} catch (error) {
@ -70,10 +71,10 @@ export function usePromptEnhancer() {
if (_error) {
logger.error(_error);
}
setEnhancingPrompt(false);
setPromptEnhanced(true);
setTimeout(() => {
setInput(_input);
});

View File

@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text';
import { stripIndents } from '~/utils/stripIndent';
import type { StreamingOptions } from '~/lib/.server/llm/stream-text';
import type { ProviderInfo } from '~/types/model';
const encoder = new TextEncoder();
const decoder = new TextDecoder();
@ -12,25 +12,27 @@ export async function action(args: ActionFunctionArgs) {
}
async function enhancerAction({ context, request }: ActionFunctionArgs) {
const { message, model, provider, apiKeys } = await request.json<{
const { message, model, provider, apiKeys } = await request.json<{
message: string;
model: string;
provider: string;
provider: ProviderInfo;
apiKeys?: Record<string, string>;
}>();
// Validate 'model' and 'provider' fields
const { name: providerName } = provider;
// validate 'model' and 'provider' fields
if (!model || typeof model !== 'string') {
throw new Response('Invalid or missing model', {
status: 400,
statusText: 'Bad Request'
statusText: 'Bad Request',
});
}
if (!provider || typeof provider !== 'string') {
if (!providerName || typeof providerName !== 'string') {
throw new Response('Invalid or missing provider', {
status: 400,
statusText: 'Bad Request'
statusText: 'Bad Request',
});
}
@ -39,7 +41,9 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
[
{
role: 'user',
content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents`
content:
`[Model: ${model}]\n\n[Provider: ${providerName}]\n\n` +
stripIndents`
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
IMPORTANT: Only respond with the improved prompt and nothing else!
@ -52,23 +56,24 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
],
context.cloudflare.env,
undefined,
apiKeys
apiKeys,
);
const transformStream = new TransformStream({
transform(chunk, controller) {
const text = decoder.decode(chunk);
const lines = text.split('\n').filter(line => line.trim() !== '');
const lines = text.split('\n').filter((line) => line.trim() !== '');
for (const line of lines) {
try {
const parsed = parseStreamPart(line);
if (parsed.type === 'text') {
controller.enqueue(encoder.encode(parsed.value));
}
} catch (e) {
// Skip invalid JSON lines
console.warn('Failed to parse stream part:', line);
// skip invalid JSON lines
console.warn('Failed to parse stream part:', line, e);
}
}
},
@ -83,7 +88,7 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
if (error instanceof Error && error.message?.includes('API key')) {
throw new Response('Invalid or missing API key', {
status: 401,
statusText: 'Unauthorized'
statusText: 'Unauthorized',
});
}