diff --git a/app/lib/modules/llm/providers/vertex-ai.ts b/app/lib/modules/llm/providers/vertex-ai.ts index 949ca097..f7ee4dde 100644 --- a/app/lib/modules/llm/providers/vertex-ai.ts +++ b/app/lib/modules/llm/providers/vertex-ai.ts @@ -4,14 +4,12 @@ import type { IProviderSetting } from '~/types/model'; import type { LanguageModelV1, LanguageModelV1CallOptions } from 'ai'; -// ... rest of the code remains the same ... - export default class VertexAIProvider extends BaseProvider { name = 'VertexAI'; getApiKeyLink = 'https://console.cloud.google.com/'; config = { - apiTokenKey: 'GOOGLE_APPLICATION_CREDENTIALS', + apiTokenKey: 'GOOGLE_ACCESS_TOKEN', projectIdKey: 'GOOGLE_PROJECT_ID', locationKey: 'GOOGLE_LOCATION', }; @@ -57,16 +55,29 @@ export default class VertexAIProvider extends BaseProvider { }): LanguageModelV1 { const { model, serverEnv = {}, apiKeys, providerSettings } = options; - const { projectId, location } = this._getVertexAIConfig({ + // Get all required credentials using base provider's method + const { apiKey: accessToken, baseUrl: projectId } = this.getProviderBaseUrlAndKey({ apiKeys, providerSettings: providerSettings?.[this.name], serverEnv, + defaultBaseUrlKey: 'GOOGLE_PROJECT_ID', + defaultApiTokenKey: 'GOOGLE_ACCESS_TOKEN', }); - if (!projectId || !location) { - throw new Error(`Missing configuration for ${this.name} provider`); + if (!accessToken) { + throw new Error(`Missing API key for ${this.name} provider`); } + if (!projectId) { + throw new Error(`Missing project ID for ${this.name} provider`); + } + + // Get location from settings or default + const location = apiKeys?.GOOGLE_LOCATION || + providerSettings?.[this.name]?.location || + serverEnv?.GOOGLE_LOCATION || + 'us-central1'; + const instance: LanguageModelV1 = { specificationVersion: 'v1', provider: this.name, @@ -75,7 +86,7 @@ export default class VertexAIProvider extends BaseProvider { async doGenerate(options: LanguageModelV1CallOptions) { const messages = options.prompt.map((msg) => ({ - role: msg.role, + role: msg.role === 'system' ? 'user' : msg.role, parts: Array.isArray(msg.content) ? msg.content.map((part) => { if ('text' in part) { @@ -87,12 +98,12 @@ export default class VertexAIProvider extends BaseProvider { })); const endpoint = `https://${location}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/publishers/google/models/${model}:generateContent`; - + const response = await fetch(endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json', - 'Authorization': `Bearer ${apiKeys?.GOOGLE_ACCESS_TOKEN}`, + Authorization: `Bearer ${accessToken}`, }, body: JSON.stringify({ contents: messages, @@ -105,11 +116,11 @@ export default class VertexAIProvider extends BaseProvider { }); if (!response.ok) { - const error = await response.json() as { error?: { message?: string } }; + const error = (await response.json()) as { error?: { message?: string } }; throw new Error(`Vertex AI API error: ${error.error?.message || 'Unknown error'}`); } - const data = await response.json() as { + const data = (await response.json()) as { candidates?: Array<{ content: { parts: Array<{ text: string }>; @@ -135,28 +146,33 @@ export default class VertexAIProvider extends BaseProvider { }; }, - async doStream(_options: LanguageModelV1CallOptions) { - throw new Error('Streaming not implemented for Vertex AI'); + async doStream(options: LanguageModelV1CallOptions) { + const response = await this.doGenerate(options); + return { + stream: new ReadableStream({ + start(controller) { + if (response.text) { + controller.enqueue({ + type: 'text-delta', + textDelta: response.text, + }); + } + controller.enqueue({ + type: 'finish', + finishReason: response.finishReason, + usage: response.usage, + }); + controller.close(); + }, + }), + rawCall: { + rawPrompt: options.prompt, + rawSettings: {}, + }, + }; }, }; return instance; } - - private _getVertexAIConfig({ - apiKeys, - providerSettings, - serverEnv, - }: { - apiKeys?: Record; - providerSettings?: IProviderSetting; - serverEnv: Record; - }) { - const projectId = apiKeys?.GOOGLE_PROJECT_ID || providerSettings?.projectId || serverEnv[this.config.projectIdKey]; - - const location = - apiKeys?.GOOGLE_LOCATION || providerSettings?.location || serverEnv[this.config.locationKey] || 'us-central1'; - - return { projectId, location }; - } }