picking right model

This commit is contained in:
Andrew Trokhymenko 2024-11-21 18:09:49 -05:00
parent 937ba7e61b
commit df94e665d6
4 changed files with 49 additions and 16 deletions

View File

@ -15,14 +15,23 @@ export function getAnthropicModel(apiKey: string, model: string) {
return anthropic(model); return anthropic(model);
} }
export function getOpenAILikeModel(baseURL:string,apiKey: string, model: string) {
export function getOpenAILikeModel(baseURL: string, apiKey: string, model: string) {
// console.log('OpenAILike config:', { baseURL, hasApiKey: !!apiKey, model });
const openai = createOpenAI({ const openai = createOpenAI({
baseURL, baseURL,
apiKey, apiKey,
}); });
// console.log('OpenAI client created:', !!openai);
return openai(model); const client = openai(model);
// console.log('OpenAI model client:', !!client);
return client;
// return {
// model: client,
// provider: 'OpenAILike' // Correctly identifying the actual provider
// };
} }
export function getOpenAIModel(apiKey: string, model: string) { export function getOpenAIModel(apiKey: string, model: string) {
const openai = createOpenAI({ const openai = createOpenAI({
apiKey, apiKey,
@ -74,7 +83,7 @@ export function getOllamaModel(baseURL: string, model: string) {
return Ollama; return Ollama;
} }
export function getDeepseekModel(apiKey: string, model: string){ export function getDeepseekModel(apiKey: string, model: string) {
const openai = createOpenAI({ const openai = createOpenAI({
baseURL: 'https://api.deepseek.com/beta', baseURL: 'https://api.deepseek.com/beta',
apiKey, apiKey,
@ -108,9 +117,15 @@ export function getXAIModel(apiKey: string, model: string) {
return openai(model); return openai(model);
} }
export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) { export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
const apiKey = getAPIKey(env, provider, apiKeys); let apiKey; // Declare first
const baseURL = getBaseURL(env, provider); let baseURL;
apiKey = getAPIKey(env, provider, apiKeys); // Then assign
baseURL = getBaseURL(env, provider);
// console.log('getModel inputs:', { provider, model, baseURL, hasApiKey: !!apiKey });
switch (provider) { switch (provider) {
case 'Anthropic': case 'Anthropic':
@ -126,11 +141,11 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re
case 'Google': case 'Google':
return getGoogleModel(apiKey, model); return getGoogleModel(apiKey, model);
case 'OpenAILike': case 'OpenAILike':
return getOpenAILikeModel(baseURL,apiKey, model); return getOpenAILikeModel(baseURL, apiKey, model);
case 'Deepseek': case 'Deepseek':
return getDeepseekModel(apiKey, model); return getDeepseekModel(apiKey, model);
case 'Mistral': case 'Mistral':
return getMistralModel(apiKey, model); return getMistralModel(apiKey, model);
case 'LMStudio': case 'LMStudio':
return getLMStudioModel(baseURL, model); return getLMStudioModel(baseURL, model);
case 'xAI': case 'xAI':
@ -138,4 +153,4 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re
default: default:
return getOllamaModel(baseURL, model); return getOllamaModel(baseURL, model);
} }
} }

View File

@ -52,6 +52,10 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
}) })
: textContent.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, ''); : textContent.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, '');
// console.log('Model from message:', model);
// console.log('Found in MODEL_LIST:', MODEL_LIST.find((m) => m.name === model));
// console.log('Current MODEL_LIST:', MODEL_LIST);
return { model, provider, content: cleanedContent }; return { model, provider, content: cleanedContent };
} }
@ -64,7 +68,7 @@ export function streamText(
let currentModel = DEFAULT_MODEL; let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER; let currentProvider = DEFAULT_PROVIDER;
console.log('StreamText:', JSON.stringify(messages)); // console.log('StreamText:', JSON.stringify(messages));
const processedMessages = messages.map((message) => { const processedMessages = messages.map((message) => {
if (message.role === 'user') { if (message.role === 'user') {
@ -82,11 +86,22 @@ export function streamText(
return message; // No changes for non-user messages return message; // No changes for non-user messages
}); });
return _streamText({ // console.log('Message content:', messages[0].content);
model: getModel(currentProvider, currentModel, env, apiKeys), // console.log('Extracted properties:', extractPropertiesFromMessage(messages[0]));
const llmClient = getModel(currentProvider, currentModel, env, apiKeys);
// console.log('LLM Client:', llmClient);
const llmConfig = {
...options,
model: llmClient, //getModel(currentProvider, currentModel, env, apiKeys),
provider: currentProvider,
system: getSystemPrompt(), system: getSystemPrompt(),
maxTokens: MAX_TOKENS, maxTokens: MAX_TOKENS,
messages: convertToCoreMessages(processedMessages), messages: convertToCoreMessages(processedMessages),
...options, };
});
// console.log('LLM Config:', llmConfig);
return _streamText(llmConfig);
} }

View File

@ -37,7 +37,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
model: string model: string
}>(); }>();
console.log('ChatAction:', JSON.stringify(messages)); // console.log('ChatAction:', JSON.stringify(messages));
const cookieHeader = request.headers.get("Cookie"); const cookieHeader = request.headers.get("Cookie");

View File

@ -32,6 +32,7 @@ const PROVIDER_LIST: ProviderInfo[] = [
name: 'OpenAILike', name: 'OpenAILike',
staticModels: [ staticModels: [
{ name: 'o1-mini', label: 'o1-mini', provider: 'OpenAILike' }, { name: 'o1-mini', label: 'o1-mini', provider: 'OpenAILike' },
{ name: 'gpt-4o-mini', label: 'GPT-4o Mini', provider: 'OpenAI' },
], ],
getDynamicModels: getOpenAILikeModels getDynamicModels: getOpenAILikeModels
}, },
@ -58,7 +59,9 @@ const PROVIDER_LIST: ProviderInfo[] = [
}, { }, {
name: 'Google', name: 'Google',
staticModels: [ staticModels: [
{ name: 'gemini-exp-1121', label: 'Gemini Experimental 1121', provider: 'Google' },
{ name: 'gemini-1.5-pro-002', label: 'Gemini 1.5 Pro 002', provider: 'Google' },
{ name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' }, { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' },
{ name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google' } { name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google' }
], ],