Change in fetching of access keys from env

This commit is contained in:
Gauravacad99 2025-02-08 22:17:37 +05:30
parent ee1be90a64
commit d39cf87dff

View File

@ -4,14 +4,12 @@ import type { IProviderSetting } from '~/types/model';
import type { LanguageModelV1, LanguageModelV1CallOptions } from 'ai'; import type { LanguageModelV1, LanguageModelV1CallOptions } from 'ai';
// ... rest of the code remains the same ...
export default class VertexAIProvider extends BaseProvider { export default class VertexAIProvider extends BaseProvider {
name = 'VertexAI'; name = 'VertexAI';
getApiKeyLink = 'https://console.cloud.google.com/'; getApiKeyLink = 'https://console.cloud.google.com/';
config = { config = {
apiTokenKey: 'GOOGLE_APPLICATION_CREDENTIALS', apiTokenKey: 'GOOGLE_ACCESS_TOKEN',
projectIdKey: 'GOOGLE_PROJECT_ID', projectIdKey: 'GOOGLE_PROJECT_ID',
locationKey: 'GOOGLE_LOCATION', locationKey: 'GOOGLE_LOCATION',
}; };
@ -57,16 +55,29 @@ export default class VertexAIProvider extends BaseProvider {
}): LanguageModelV1 { }): LanguageModelV1 {
const { model, serverEnv = {}, apiKeys, providerSettings } = options; const { model, serverEnv = {}, apiKeys, providerSettings } = options;
const { projectId, location } = this._getVertexAIConfig({ // Get all required credentials using base provider's method
const { apiKey: accessToken, baseUrl: projectId } = this.getProviderBaseUrlAndKey({
apiKeys, apiKeys,
providerSettings: providerSettings?.[this.name], providerSettings: providerSettings?.[this.name],
serverEnv, serverEnv,
defaultBaseUrlKey: 'GOOGLE_PROJECT_ID',
defaultApiTokenKey: 'GOOGLE_ACCESS_TOKEN',
}); });
if (!projectId || !location) { if (!accessToken) {
throw new Error(`Missing configuration for ${this.name} provider`); throw new Error(`Missing API key for ${this.name} provider`);
} }
if (!projectId) {
throw new Error(`Missing project ID for ${this.name} provider`);
}
// Get location from settings or default
const location = apiKeys?.GOOGLE_LOCATION ||
providerSettings?.[this.name]?.location ||
serverEnv?.GOOGLE_LOCATION ||
'us-central1';
const instance: LanguageModelV1 = { const instance: LanguageModelV1 = {
specificationVersion: 'v1', specificationVersion: 'v1',
provider: this.name, provider: this.name,
@ -75,7 +86,7 @@ export default class VertexAIProvider extends BaseProvider {
async doGenerate(options: LanguageModelV1CallOptions) { async doGenerate(options: LanguageModelV1CallOptions) {
const messages = options.prompt.map((msg) => ({ const messages = options.prompt.map((msg) => ({
role: msg.role, role: msg.role === 'system' ? 'user' : msg.role,
parts: Array.isArray(msg.content) parts: Array.isArray(msg.content)
? msg.content.map((part) => { ? msg.content.map((part) => {
if ('text' in part) { if ('text' in part) {
@ -92,7 +103,7 @@ export default class VertexAIProvider extends BaseProvider {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': `Bearer ${apiKeys?.GOOGLE_ACCESS_TOKEN}`, Authorization: `Bearer ${accessToken}`,
}, },
body: JSON.stringify({ body: JSON.stringify({
contents: messages, contents: messages,
@ -105,11 +116,11 @@ export default class VertexAIProvider extends BaseProvider {
}); });
if (!response.ok) { if (!response.ok) {
const error = await response.json() as { error?: { message?: string } }; const error = (await response.json()) as { error?: { message?: string } };
throw new Error(`Vertex AI API error: ${error.error?.message || 'Unknown error'}`); throw new Error(`Vertex AI API error: ${error.error?.message || 'Unknown error'}`);
} }
const data = await response.json() as { const data = (await response.json()) as {
candidates?: Array<{ candidates?: Array<{
content: { content: {
parts: Array<{ text: string }>; parts: Array<{ text: string }>;
@ -135,28 +146,33 @@ export default class VertexAIProvider extends BaseProvider {
}; };
}, },
async doStream(_options: LanguageModelV1CallOptions) { async doStream(options: LanguageModelV1CallOptions) {
throw new Error('Streaming not implemented for Vertex AI'); const response = await this.doGenerate(options);
return {
stream: new ReadableStream({
start(controller) {
if (response.text) {
controller.enqueue({
type: 'text-delta',
textDelta: response.text,
});
}
controller.enqueue({
type: 'finish',
finishReason: response.finishReason,
usage: response.usage,
});
controller.close();
},
}),
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
};
}, },
}; };
return instance; return instance;
} }
private _getVertexAIConfig({
apiKeys,
providerSettings,
serverEnv,
}: {
apiKeys?: Record<string, string>;
providerSettings?: IProviderSetting;
serverEnv: Record<string, string>;
}) {
const projectId = apiKeys?.GOOGLE_PROJECT_ID || providerSettings?.projectId || serverEnv[this.config.projectIdKey];
const location =
apiKeys?.GOOGLE_LOCATION || providerSettings?.location || serverEnv[this.config.locationKey] || 'us-central1';
return { projectId, location };
}
} }