Refactor AWS Bedrock Claude Pipeline to support Instance Profile and Task Role

- Updated `Valves` class to use `Optional[str]` for AWS credentials.
- Modified `__init__` method to initialize `valves` with environment variables.
- Added `update_pipelines` method to handle Bedrock client initialization and model fetching.
- Refactored `on_startup` and `on_valves_updated` methods to call `update_pipelines`.
- Improved error handling in `update_pipelines` and `get_models` methods.
This commit is contained in:
Takahiro Kikumoto 2025-03-18 16:33:44 +09:00
parent 51e267c10f
commit c1bbbe1165

View File

@ -12,7 +12,7 @@ import base64
import json import json
import logging import logging
from io import BytesIO from io import BytesIO
from typing import List, Union, Generator, Iterator from typing import List, Union, Generator, Iterator, Optional, Any
import boto3 import boto3
@ -26,9 +26,9 @@ from utils.pipelines.main import pop_system_message
class Pipeline: class Pipeline:
class Valves(BaseModel): class Valves(BaseModel):
AWS_ACCESS_KEY: str = "" AWS_ACCESS_KEY: Optional[str] = None
AWS_SECRET_KEY: str = "" AWS_SECRET_KEY: Optional[str] = None
AWS_REGION_NAME: str = "" AWS_REGION_NAME: Optional[str] = None
def __init__(self): def __init__(self):
self.type = "manifold" self.type = "manifold"
@ -47,21 +47,25 @@ class Pipeline:
} }
) )
self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY, self.valves = self.Valves(
aws_secret_access_key=self.valves.AWS_SECRET_KEY, **{
service_name="bedrock", "AWS_ACCESS_KEY": os.getenv("AWS_ACCESS_KEY", ""),
region_name=self.valves.AWS_REGION_NAME) "AWS_SECRET_KEY": os.getenv("AWS_SECRET_KEY", ""),
self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY, "AWS_REGION_NAME": os.getenv(
aws_secret_access_key=self.valves.AWS_SECRET_KEY, "AWS_REGION_NAME", os.getenv(
service_name="bedrock-runtime", "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")
region_name=self.valves.AWS_REGION_NAME) )
),
}
)
self.pipelines = self.get_models() self.update_pipelines()
async def on_startup(self): async def on_startup(self):
# This function is called when the server is started. # This function is called when the server is started.
print(f"on_startup:{__name__}") print(f"on_startup:{__name__}")
self.update_pipelines()
pass pass
async def on_shutdown(self): async def on_shutdown(self):
@ -72,40 +76,46 @@ class Pipeline:
async def on_valves_updated(self): async def on_valves_updated(self):
# This function is called when the valves are updated. # This function is called when the valves are updated.
print(f"on_valves_updated:{__name__}") print(f"on_valves_updated:{__name__}")
self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY, self.update_pipelines()
aws_secret_access_key=self.valves.AWS_SECRET_KEY,
service_name="bedrock",
region_name=self.valves.AWS_REGION_NAME)
self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
aws_secret_access_key=self.valves.AWS_SECRET_KEY,
service_name="bedrock-runtime",
region_name=self.valves.AWS_REGION_NAME)
self.pipelines = self.get_models()
def pipelines(self) -> List[dict]: def update_pipelines(self) -> None:
return self.get_models() try:
self.bedrock = boto3.client(service_name="bedrock",
aws_access_key_id=self.valves.AWS_ACCESS_KEY,
aws_secret_access_key=self.valves.AWS_SECRET_KEY,
region_name=self.valves.AWS_REGION_NAME)
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime",
aws_access_key_id=self.valves.AWS_ACCESS_KEY,
aws_secret_access_key=self.valves.AWS_SECRET_KEY,
region_name=self.valves.AWS_REGION_NAME)
self.pipelines = self.get_models()
except Exception as e:
print(f"Error: {e}")
self.pipelines = [
{
"id": "error",
"name": "Could not fetch models from Bedrock, please set up AWS Key/Secret or Instance/Task Role.",
},
]
def get_models(self): def get_models(self):
if self.valves.AWS_ACCESS_KEY and self.valves.AWS_SECRET_KEY: try:
try: response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND')
response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND') return [
return [ {
{ "id": model["modelId"],
"id": model["modelId"], "name": model["modelName"],
"name": model["modelName"], }
} for model in response["modelSummaries"]
for model in response["modelSummaries"] ]
] except Exception as e:
except Exception as e: print(f"Error: {e}")
print(f"Error: {e}") return [
return [ {
{ "id": "error",
"id": "error", "name": "Could not fetch models from Bedrock, please check permissoin.",
"name": "Could not fetch models from Bedrock, please update the Access/Secret Key in the valves.", },
}, ]
]
else:
return []
def pipe( def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict self, user_message: str, model_id: str, messages: List[dict], body: dict