Deduplicate static/dynamic model lists and add dynamic model discovery for Google and Groq.

What this patch does:
  * manager.ts: when a model is reported both statically and dynamically, keep only the
    dynamic entry, and return the combined list sorted by model name.
  * google.ts / groq.ts: add getDynamicModels(), which queries the provider's models endpoint.
  * api.models.ts: the per-provider route now delegates to llmManager.getModelListFromProvider().

Review notes on this revision:
  * Restored the generic arguments (Record<string, string>, Promise<ModelInfo[]>) that had been
    stripped from the getDynamicModels signatures.
  * Fixed the misspelled local `filteredStaticModesl`.
  * getDynamicModels now throws Error objects instead of strings, checks response.ok, and
    tolerates a payload without a model array.
  * Duplicate detection uses a Set instead of Array.includes inside filter.
  * Hunk headers were recomputed for the added lines. The post-image blob ids on the `index`
    lines still refer to the originally submitted content.
  * NOTE(review): the source of this patch had its whitespace collapsed; indentation and blank
    context lines were reconstructed from the hunk line counts. Confirm against the target files
    (or apply with whitespace-tolerant options) before merging.

diff --git a/app/lib/modules/llm/manager.ts b/app/lib/modules/llm/manager.ts
index 88ae28c9..aec91905 100644
--- a/app/lib/modules/llm/manager.ts
+++ b/app/lib/modules/llm/manager.ts
@@ -118,12 +118,15 @@ export class LLMManager {
           return dynamicModels;
         }),
     );
+    // Drop static entries that are also reported dynamically (same name and provider).
+    const staticModels = Array.from(this._providers.values()).flatMap((p) => p.staticModels || []);
+    const dynamicModelsFlat = dynamicModels.flat();
+    const dynamicModelKeys = new Set(dynamicModelsFlat.map((d) => `${d.name}-${d.provider}`));
+    const filteredStaticModels = staticModels.filter((m) => !dynamicModelKeys.has(`${m.name}-${m.provider}`));
 
     // Combine static and dynamic models
-    const modelList = [
-      ...dynamicModels.flat(),
-      ...Array.from(this._providers.values()).flatMap((p) => p.staticModels || []),
-    ];
+    const modelList = [...dynamicModelsFlat, ...filteredStaticModels];
+    modelList.sort((a, b) => a.name.localeCompare(b.name));
     this._modelList = modelList;
 
     return modelList;
@@ -178,8 +181,13 @@ export class LLMManager {
         logger.error(`Error getting dynamic models ${provider.name} :`, err);
         return [];
       });
+    // Drop static entries whose name is also reported dynamically; the dynamic entry wins.
+    const dynamicModelNames = new Set(dynamicModels.map((d) => d.name));
+    const filteredStaticList = staticModels.filter((m) => !dynamicModelNames.has(m.name));
+    const modelList = [...dynamicModels, ...filteredStaticList];
+    modelList.sort((a, b) => a.name.localeCompare(b.name));
 
-    return [...dynamicModels, ...staticModels];
+    return modelList;
   }
   getStaticModelListFromProvider(providerArg: BaseProvider) {
     const provider = this._providers.get(providerArg.name);
diff --git a/app/lib/modules/llm/providers/google.ts b/app/lib/modules/llm/providers/google.ts
index b69356c8..67043bad 100644
--- a/app/lib/modules/llm/providers/google.ts
+++ b/app/lib/modules/llm/providers/google.ts
@@ -14,7 +14,12 @@ export default class GoogleProvider extends BaseProvider {
 
   staticModels: ModelInfo[] = [
     { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google', maxTokenAllowed: 8192 },
-    { name: 'gemini-2.0-flash-thinking-exp-01-21', label: 'Gemini 2.0 Flash-thinking-exp-01-21', provider: 'Google', maxTokenAllowed: 65536 },
+    {
+      name: 'gemini-2.0-flash-thinking-exp-01-21',
+      label: 'Gemini 2.0 Flash-thinking-exp-01-21',
+      provider: 'Google',
+      maxTokenAllowed: 65536,
+    },
     { name: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash', provider: 'Google', maxTokenAllowed: 8192 },
     { name: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash-002', provider: 'Google', maxTokenAllowed: 8192 },
     { name: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash-8b', provider: 'Google', maxTokenAllowed: 8192 },
@@ -23,6 +28,58 @@ export default class GoogleProvider extends BaseProvider {
     { name: 'gemini-exp-1206', label: 'Gemini exp-1206', provider: 'Google', maxTokenAllowed: 8192 },
   ];
 
+  /**
+   * Lists models from the Generative Language API `models` endpoint using the configured API key.
+   * Only models whose output token limit exceeds 8000 are returned.
+   *
+   * @throws Error when no API key is configured or the API answers with a non-2xx status.
+   */
+  async getDynamicModels(
+    apiKeys?: Record<string, string>,
+    settings?: IProviderSetting,
+    serverEnv?: Record<string, string>,
+  ): Promise<ModelInfo[]> {
+    const { apiKey } = this.getProviderBaseUrlAndKey({
+      apiKeys,
+      providerSettings: settings,
+      serverEnv: serverEnv as any,
+      defaultBaseUrlKey: '',
+      defaultApiTokenKey: 'GOOGLE_GENERATIVE_AI_API_KEY',
+    });
+
+    if (!apiKey) {
+      throw new Error(`Missing Api Key configuration for ${this.name} provider`);
+    }
+
+    const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`, {
+      headers: {
+        ['Content-Type']: 'application/json',
+      },
+    });
+
+    if (!response.ok) {
+      throw new Error(`Failed to fetch ${this.name} models: ${response.status} ${response.statusText}`);
+    }
+
+    const res = (await response.json()) as any;
+
+    // A payload without a `models` array is treated as an empty list instead of crashing.
+    const models: any[] = Array.isArray(res?.models) ? res.models : [];
+    const data = models.filter((model: any) => model.outputTokenLimit > 8000);
+
+    return data.map((m: any) => {
+      // A missing input limit would otherwise turn the sum into NaN and the label into "NaNk".
+      const contextTokens = (m.inputTokenLimit ?? 0) + m.outputTokenLimit;
+
+      return {
+        name: m.name.replace('models/', ''),
+        label: `${m.displayName} - context ${Math.floor(contextTokens / 1000) + 'k'}`,
+        provider: this.name,
+        maxTokenAllowed: contextTokens || 8000,
+      };
+    });
+  }
+
   getModelInstance(options: {
     model: string;
     serverEnv: any;
diff --git a/app/lib/modules/llm/providers/groq.ts b/app/lib/modules/llm/providers/groq.ts
index 24d4f155..e9d2b0bd 100644
--- a/app/lib/modules/llm/providers/groq.ts
+++ b/app/lib/modules/llm/providers/groq.ts
@@ -19,9 +19,64 @@ export default class GroqProvider extends BaseProvider {
     { name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
     { name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
     { name: 'llama-3.3-70b-versatile', label: 'Llama 3.3 70b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
-    { name: 'deepseek-r1-distill-llama-70b', label: 'Deepseek R1 Distill Llama 70b (Groq)', provider: 'Groq', maxTokenAllowed: 131072 },
+    {
+      name: 'deepseek-r1-distill-llama-70b',
+      label: 'Deepseek R1 Distill Llama 70b (Groq)',
+      provider: 'Groq',
+      maxTokenAllowed: 131072,
+    },
   ];
 
+  /**
+   * Lists models from the Groq `/openai/v1/models` endpoint using the configured API key.
+   * Only active entries of object type `model` with a context window above 8000 are returned.
+   *
+   * @throws Error when no API key is configured or the API answers with a non-2xx status.
+   */
+  async getDynamicModels(
+    apiKeys?: Record<string, string>,
+    settings?: IProviderSetting,
+    serverEnv?: Record<string, string>,
+  ): Promise<ModelInfo[]> {
+    const { apiKey } = this.getProviderBaseUrlAndKey({
+      apiKeys,
+      providerSettings: settings,
+      serverEnv: serverEnv as any,
+      defaultBaseUrlKey: '',
+      defaultApiTokenKey: 'GROQ_API_KEY',
+    });
+
+    if (!apiKey) {
+      throw new Error(`Missing Api Key configuration for ${this.name} provider`);
+    }
+
+    const response = await fetch(`https://api.groq.com/openai/v1/models`, {
+      headers: {
+        Authorization: `Bearer ${apiKey}`,
+      },
+    });
+
+    if (!response.ok) {
+      throw new Error(`Failed to fetch ${this.name} models: ${response.status} ${response.statusText}`);
+    }
+
+    const res = (await response.json()) as any;
+
+    // A payload without a `data` array is treated as an empty list instead of crashing.
+    const models: any[] = Array.isArray(res?.data) ? res.data : [];
+    const data = models.filter(
+      (model: any) => model.object === 'model' && model.active && model.context_window > 8000,
+    );
+
+    // The filter above only keeps entries with context_window > 8000, so no fallback is needed here.
+    return data.map((m: any) => ({
+      name: m.id,
+      label: `${m.id} - context ${Math.floor(m.context_window / 1000) + 'k'} [ by ${m.owned_by}]`,
+      provider: this.name,
+      maxTokenAllowed: m.context_window,
+    }));
+  }
+
   getModelInstance(options: {
     model: string;
     serverEnv: Env;
diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts
index f7512226..5fad834d 100644
--- a/app/routes/api.models.ts
+++ b/app/routes/api.models.ts
@@ -67,11 +67,11 @@ export async function loader({
     const provider = llmManager.getProvider(params.provider);
 
     if (provider) {
-      const staticModels = provider.staticModels;
-      const dynamicModels = provider.getDynamicModels
-        ? await provider.getDynamicModels(apiKeys, providerSettings, context.cloudflare?.env)
-        : [];
-      modelList = [...staticModels, ...dynamicModels];
+      modelList = await llmManager.getModelListFromProvider(provider, {
+        apiKeys,
+        providerSettings,
+        serverEnv: context.cloudflare?.env,
+      });
     }
   } else {
     // Update all models