From ecc44ebd1e0707db56af4e186cc71e4ead9bbced Mon Sep 17 00:00:00 2001
From: Takahiro Kikumoto
Date: Tue, 18 Mar 2025 16:48:12 +0900
Subject: [PATCH] Enhance get_models method to include models with
 INFERENCE_PROFILE type

- Updated the get_models method to fetch models that support both ON_DEMAND
  and INFERENCE_PROFILE inference types.
- Added a helper method getInferenceProfileId to retrieve the inference
  profile ID for models with INFERENCE_PROFILE type.
- This change ensures that models with different inference types are
  correctly listed and available for use.
---
 .../providers/aws_bedrock_claude_pipeline.py | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/examples/pipelines/providers/aws_bedrock_claude_pipeline.py b/examples/pipelines/providers/aws_bedrock_claude_pipeline.py
index f347d77..245a046 100644
--- a/examples/pipelines/providers/aws_bedrock_claude_pipeline.py
+++ b/examples/pipelines/providers/aws_bedrock_claude_pipeline.py
@@ -100,14 +100,18 @@ class Pipeline:
 
     def get_models(self):
         try:
-            response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND')
-            return [
-                {
-                    "id": model["modelId"],
-                    "name": model["modelName"],
-                }
-                for model in response["modelSummaries"]
-            ]
+            res = []
+            response = self.bedrock.list_foundation_models(byProvider='Anthropic')
+            for model in response['modelSummaries']:
+                inference_types = model.get('inferenceTypesSupported', [])
+                if "ON_DEMAND" in inference_types:
+                    res.append({'id': model['modelId'], 'name': model['modelName']})
+                elif "INFERENCE_PROFILE" in inference_types:
+                    inferenceProfileId = self.getInferenceProfileId(model['modelArn'])
+                    if inferenceProfileId:
+                        res.append({'id': inferenceProfileId, 'name': model['modelName']})
+
+            return res
         except Exception as e:
             print(f"Error: {e}")
             return [
@@ -117,6 +121,14 @@ class Pipeline:
                 },
             ]
 
+    def getInferenceProfileId(self, modelArn: str) -> str:
+        response = self.bedrock.list_inference_profiles()
+        for profile in response.get('inferenceProfileSummaries', []):
+            for model in profile.get('models', []):
+                if model.get('modelArn') == modelArn:
+                    return profile['inferenceProfileId']
+        return None
+
     def pipe(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
     ) -> Union[str, Generator, Iterator]:
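
A minimal standalone sketch of the new listing flow, for trying it outside the pipeline class (not part of the diff above). It assumes boto3 credentials and region are configured; the region name is only an example, and pagination of list_inference_profiles is omitted for brevity.

# Standalone sketch mirroring the patched get_models() logic.
# Assumes boto3 is installed and AWS credentials are configured;
# "us-east-1" is only an example region.
import boto3

bedrock = boto3.client("bedrock", region_name="us-east-1")

def inference_profile_id_for(model_arn: str):
    # Same lookup as the new getInferenceProfileId helper: scan the account's
    # inference profiles and return the ID of the one wrapping this model ARN.
    # (A single call is used here; nextToken pagination is not handled.)
    response = bedrock.list_inference_profiles()
    for profile in response.get("inferenceProfileSummaries", []):
        for model in profile.get("models", []):
            if model.get("modelArn") == model_arn:
                return profile["inferenceProfileId"]
    return None

models = []
for model in bedrock.list_foundation_models(byProvider="Anthropic")["modelSummaries"]:
    types = model.get("inferenceTypesSupported", [])
    if "ON_DEMAND" in types:
        models.append({"id": model["modelId"], "name": model["modelName"]})
    elif "INFERENCE_PROFILE" in types:
        profile_id = inference_profile_id_for(model["modelArn"])
        if profile_id:
            models.append({"id": profile_id, "name": model["modelName"]})

for entry in models:
    print(entry["id"], "->", entry["name"])

Running this should list on-demand Anthropic model IDs alongside inference-profile IDs, which is the set the patched get_models() would expose to the pipeline.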