From 31383bc9a6a5d6bff485d348a70c11f62694a7e6 Mon Sep 17 00:00:00 2001 From: morningxm Date: Sun, 16 Mar 2025 14:04:23 +0000 Subject: [PATCH] feat: custom order method for llm models --- app/lib/modules/llm/base-provider.ts | 4 + app/lib/modules/llm/manager.ts | 100 ++++++++++----------- app/lib/modules/llm/providers/anthropic.ts | 43 +++++++-- app/lib/modules/llm/types.ts | 1 + 4 files changed, 92 insertions(+), 56 deletions(-) diff --git a/app/lib/modules/llm/base-provider.ts b/app/lib/modules/llm/base-provider.ts index 9cb23403..b09b0140 100644 --- a/app/lib/modules/llm/base-provider.ts +++ b/app/lib/modules/llm/base-provider.ts @@ -115,6 +115,10 @@ export abstract class BaseProvider implements ProviderInfo { apiKeys?: Record; providerSettings?: Record; }): LanguageModelV1; + + sort(models: ModelInfo[]): ModelInfo[] { + return models.sort((a, b) => a.name.localeCompare(b.name)); + } } type OptionalApiKey = string | undefined; diff --git a/app/lib/modules/llm/manager.ts b/app/lib/modules/llm/manager.ts index aec91905..c96fbf18 100644 --- a/app/lib/modules/llm/manager.ts +++ b/app/lib/modules/llm/manager.ts @@ -1,6 +1,6 @@ import type { IProviderSetting } from '~/types/model'; import { BaseProvider } from './base-provider'; -import type { ModelInfo, ProviderInfo } from './types'; +import type { ModelInfo } from './types'; import * as providers from './registry'; import { createScopedLogger } from '~/utils/logger'; @@ -88,44 +88,44 @@ export class LLMManager { } // Get dynamic models from all providers that support them - const dynamicModels = await Promise.all( + const models = await Promise.all( Array.from(this._providers.values()) .filter((provider) => enabledProviders.includes(provider.name)) - .filter( - (provider): provider is BaseProvider & Required> => - !!provider.getDynamicModels, - ) .map(async (provider) => { - const cachedModels = provider.getModelsFromCache(options); + let dynamicModels: ModelInfo[] | null = []; - if (cachedModels) { - return cachedModels; + if (provider.getDynamicModels) { + dynamicModels = provider.getModelsFromCache(options); + + if (!dynamicModels) { + dynamicModels = await provider + .getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv) + .then((models) => { + logger.info(`Caching ${models.length} dynamic models for ${provider.name}`); + provider.storeDynamicModels(options, models); + + return models; + }) + .catch((err) => { + logger.error(`Error getting dynamic models ${provider.name} :`, err); + return []; + }); + } } - const dynamicModels = await provider - .getDynamicModels(apiKeys, providerSettings?.[provider.name], serverEnv) - .then((models) => { - logger.info(`Caching ${models.length} dynamic models for ${provider.name}`); - provider.storeDynamicModels(options, models); + const models = provider.sort([ + ...dynamicModels, + ...provider.staticModels.filter( + (staticModel) => !dynamicModels.some((dynamicModel) => dynamicModel.name === staticModel.name), + ), + ]); - return models; - }) - .catch((err) => { - logger.error(`Error getting dynamic models ${provider.name} :`, err); - return []; - }); - - return dynamicModels; + return models; }), ); - const staticModels = Array.from(this._providers.values()).flatMap((p) => p.staticModels || []); - const dynamicModelsFlat = dynamicModels.flat(); - const dynamicModelKeys = dynamicModelsFlat.map((d) => `${d.name}-${d.provider}`); - const filteredStaticModesl = staticModels.filter((m) => !dynamicModelKeys.includes(`${m.name}-${m.provider}`)); // Combine static and dynamic models - const modelList = [...dynamicModelsFlat, ...filteredStaticModesl]; - modelList.sort((a, b) => a.name.localeCompare(b.name)); + const modelList = models.flat(); this._modelList = modelList; return modelList; @@ -155,35 +155,35 @@ export class LLMManager { const { apiKeys, providerSettings, serverEnv } = options; - const cachedModels = provider.getModelsFromCache({ + let dynamicModels = provider.getModelsFromCache({ apiKeys, providerSettings, serverEnv, }); - if (cachedModels) { - logger.info(`Found ${cachedModels.length} cached models for ${provider.name}`); - return [...cachedModels, ...staticModels]; + if (!dynamicModels) { + logger.info(`Getting dynamic models for ${provider.name}`); + + dynamicModels = await provider + .getDynamicModels?.(apiKeys, providerSettings?.[provider.name], serverEnv) + .then((models) => { + logger.info(`Got ${models.length} dynamic models for ${provider.name}`); + provider.storeDynamicModels(options, models); + + return models; + }) + .catch((err) => { + logger.error(`Error getting dynamic models ${provider.name} :`, err); + return []; + }); } - logger.info(`Getting dynamic models for ${provider.name}`); - - const dynamicModels = await provider - .getDynamicModels?.(apiKeys, providerSettings?.[provider.name], serverEnv) - .then((models) => { - logger.info(`Got ${models.length} dynamic models for ${provider.name}`); - provider.storeDynamicModels(options, models); - - return models; - }) - .catch((err) => { - logger.error(`Error getting dynamic models ${provider.name} :`, err); - return []; - }); - const dynamicModelsName = dynamicModels.map((d) => d.name); - const filteredStaticList = staticModels.filter((m) => !dynamicModelsName.includes(m.name)); - const modelList = [...dynamicModels, ...filteredStaticList]; - modelList.sort((a, b) => a.name.localeCompare(b.name)); + const modelList = provider.sort([ + ...dynamicModels, + ...staticModels.filter( + (staticModel) => !dynamicModels.some((dynamicModel) => dynamicModel.name === staticModel.name), + ), + ]); return modelList; } diff --git a/app/lib/modules/llm/providers/anthropic.ts b/app/lib/modules/llm/providers/anthropic.ts index 70f93c07..62886639 100644 --- a/app/lib/modules/llm/providers/anthropic.ts +++ b/app/lib/modules/llm/providers/anthropic.ts @@ -18,28 +18,51 @@ export default class AnthropicProvider extends BaseProvider { label: 'Claude 3.7 Sonnet', provider: 'Anthropic', maxTokenAllowed: 8000, + createdAt: '2025-02-24T00:00:00Z', }, { name: 'claude-3-5-sonnet-latest', - label: 'Claude 3.5 Sonnet (new)', + label: 'Claude 3.5 Sonnet (Latest)', provider: 'Anthropic', maxTokenAllowed: 8000, + createdAt: '2024-10-22T00:00:00Z', }, { name: 'claude-3-5-sonnet-20240620', - label: 'Claude 3.5 Sonnet (old)', + label: 'Claude 3.5 Sonnet (Old)', provider: 'Anthropic', maxTokenAllowed: 8000, + createdAt: '2024-06-20T00:00:00Z', }, { name: 'claude-3-5-haiku-latest', - label: 'Claude 3.5 Haiku (new)', + label: 'Claude 3.5 Haiku (Latest)', provider: 'Anthropic', maxTokenAllowed: 8000, + createdAt: '2024-10-22T00:00:00Z', + }, + + { + name: 'claude-3-opus-latest', + label: 'Claude 3 Opus', + provider: 'Anthropic', + maxTokenAllowed: 8000, + createdAt: '2024-02-29T00:00:00Z', + }, + { + name: 'claude-3-sonnet-20240229', + label: 'Claude 3 Sonnet', + provider: 'Anthropic', + maxTokenAllowed: 8000, + createdAt: '2024-02-29T00:00:00Z', + }, + { + name: 'claude-3-haiku-20240307', + label: 'Claude 3 Haiku', + provider: 'Anthropic', + maxTokenAllowed: 8000, + createdAt: '2024-03-07T00:00:00Z', }, - { name: 'claude-3-opus-latest', label: 'Claude 3 Opus', provider: 'Anthropic', maxTokenAllowed: 8000 }, - { name: 'claude-3-sonnet-20240229', label: 'Claude 3 Sonnet', provider: 'Anthropic', maxTokenAllowed: 8000 }, - { name: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku', provider: 'Anthropic', maxTokenAllowed: 8000 }, ]; async getDynamicModels( @@ -76,6 +99,7 @@ export default class AnthropicProvider extends BaseProvider { label: `${m.display_name}`, provider: this.name, maxTokenAllowed: 32000, + createdAt: m.created_at ?? '0000-00-00T00:00:00Z', })); } @@ -99,4 +123,11 @@ export default class AnthropicProvider extends BaseProvider { return anthropic(model); }; + + sort(models: ModelInfo[]): ModelInfo[] { + return models.sort((a, b) => { + const compare = b.createdAt!.localeCompare(a.createdAt!); + return compare === 0 ? b.name.localeCompare(a.name) : compare; + }); + } } diff --git a/app/lib/modules/llm/types.ts b/app/lib/modules/llm/types.ts index 421d6dfc..303202f1 100644 --- a/app/lib/modules/llm/types.ts +++ b/app/lib/modules/llm/types.ts @@ -6,6 +6,7 @@ export interface ModelInfo { label: string; provider: string; maxTokenAllowed: number; + createdAt?: string; } export interface ProviderInfo {