From 972696450e21820a2e278d359dc9421ca18922d0 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Thu, 3 Oct 2024 21:02:03 +0300 Subject: [PATCH] Black formatting --- clearml/automation/aws_driver.py | 98 ++++++++++++++------------------ 1 file changed, 42 insertions(+), 56 deletions(-) diff --git a/clearml/automation/aws_driver.py b/clearml/automation/aws_driver.py index 1f0b75de..c7def72c 100644 --- a/clearml/automation/aws_driver.py +++ b/clearml/automation/aws_driver.py @@ -16,34 +16,34 @@ try: Task.add_requirements("boto3") except ImportError as err: raise ImportError( - "AwsAutoScaler requires 'boto3' package, it was not found\n" - "install with: pip install boto3" + "AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3" ) from err @attr.s class AWSDriver(CloudDriver): """AWS Driver""" - aws_access_key_id = attr.ib(validator=instance_of(str), default='') - aws_secret_access_key = attr.ib(validator=instance_of(str), default='') - aws_session_token = attr.ib(validator=instance_of(str), default='') - aws_region = attr.ib(validator=instance_of(str), default='') + + aws_access_key_id = attr.ib(validator=instance_of(str), default="") + aws_secret_access_key = attr.ib(validator=instance_of(str), default="") + aws_session_token = attr.ib(validator=instance_of(str), default="") + aws_region = attr.ib(validator=instance_of(str), default="") use_credentials_chain = attr.ib(validator=instance_of(bool), default=False) use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False) - iam_arn = attr.ib(validator=instance_of(str), default='') - iam_name = attr.ib(validator=instance_of(str), default='') + iam_arn = attr.ib(validator=instance_of(str), default="") + iam_name = attr.ib(validator=instance_of(str), default="") @classmethod def from_config(cls, config): obj = super().from_config(config) - obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key') - obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret') - obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token') - obj.aws_region = config['hyper_params'].get('cloud_credentials_region') - obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False) - obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False) - obj.iam_arn = config['hyper_params'].get('iam_arn') - obj.iam_name = config['hyper_params'].get('iam_name') + obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key") + obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret") + obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token") + obj.aws_region = config["hyper_params"].get("cloud_credentials_region") + obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False) + obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False) + obj.iam_arn = config["hyper_params"].get("iam_arn") + obj.iam_name = config["hyper_params"].get("iam_name") return obj def __attrs_post_init__(self): @@ -60,7 +60,7 @@ class AWSDriver(CloudDriver): launch_specification = ConfigFactory.from_dict( { "ImageId": resource_conf["ami_id"], - "Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))}, + "Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))}, "InstanceType": resource_conf["instance_type"], } ) @@ -70,9 +70,7 @@ class AWSDriver(CloudDriver): launch_specification["BlockDeviceMappings"] = [ { "DeviceName": resource_conf["ebs_device_name"], - "Ebs": { - "SnapshotId": resource_conf["ebs_snapshot_id"] - } + "Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]}, } ] elif resource_conf.get("ebs_device_name"): @@ -81,8 +79,8 @@ class AWSDriver(CloudDriver): "DeviceName": resource_conf["ebs_device_name"], "Ebs": { "VolumeSize": resource_conf.get("ebs_volume_size", 80), - "VolumeType": resource_conf.get("ebs_volume_type", "gp3") - } + "VolumeType": resource_conf.get("ebs_volume_type", "gp3"), + }, } ] @@ -91,45 +89,33 @@ class AWSDriver(CloudDriver): elif resource_conf.get("availability_zone", None): launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]} else: - raise Exception('subnet_id or availability_zone must to be specified in the config') + raise Exception("subnet_id or availability_zone must to be specified in the config") if resource_conf.get("key_name", None): launch_specification["KeyName"] = resource_conf["key_name"] if resource_conf.get("security_group_ids", None): - launch_specification["SecurityGroupIds"] = resource_conf[ - "security_group_ids" - ] + launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"] # Adding iam role - you can have Arn OR Name, not both, Arn getting priority if self.iam_arn: launch_specification["IamInstanceProfile"] = { - 'Arn': self.iam_arn, + "Arn": self.iam_arn, } elif self.iam_name: - launch_specification["IamInstanceProfile"] = { - 'Name': self.iam_name - } + launch_specification["IamInstanceProfile"] = {"Name": self.iam_name} if resource_conf["is_spot"]: # Create a request for a spot instance in AWS - encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode( - "ascii" - ) + encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii") launch_specification["UserData"] = encoded_user_data - ConfigTree.merge_configs( - launch_specification, resource_conf.get("extra_configurations", {}) - ) + ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {})) - instances = ec2.request_spot_instances( - LaunchSpecification=launch_specification - ) + instances = ec2.request_spot_instances(LaunchSpecification=launch_specification) # Wait until spot request is fulfilled request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"] waiter = ec2.get_waiter("spot_instance_request_fulfilled") waiter.wait(SpotInstanceRequestIds=[request_id]) # Get the instance object for later use - response = ec2.describe_spot_instance_requests( - SpotInstanceRequestIds=[request_id] - ) + response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id]) instance_id = response["SpotInstanceRequests"][0]["InstanceId"] else: @@ -140,9 +126,7 @@ class AWSDriver(CloudDriver): UserData=user_data, InstanceInitiatedShutdownBehavior="terminate", ) - ConfigTree.merge_configs( - launch_specification, resource_conf.get("extra_configurations", {}) - ) + ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {})) instances = ec2.run_instances(**launch_specification) @@ -165,30 +149,32 @@ class AWSDriver(CloudDriver): def creds(self): creds = { - 'region_name': self.aws_region or None, + "region_name": self.aws_region or None, } if not self.use_credentials_chain: - creds.update({ - 'aws_secret_access_key': self.aws_secret_access_key or None, - 'aws_access_key_id': self.aws_access_key_id or None, - 'aws_session_token': self.aws_session_token or None, - }) + creds.update( + { + "aws_secret_access_key": self.aws_secret_access_key or None, + "aws_access_key_id": self.aws_access_key_id or None, + "aws_session_token": self.aws_session_token or None, + } + ) return creds def instance_id_command(self): - return 'curl http://169.254.169.254/latest/meta-data/instance-id' + return "curl http://169.254.169.254/latest/meta-data/instance-id" def instance_type_key(self): - return 'instance_type' + return "instance_type" def kind(self): - return 'AWS' + return "AWS" def console_log(self, instance_id): ec2 = boto3.client("ec2", **self.creds()) try: out = ec2.get_console_output(InstanceId=instance_id) - return out.get('Output', '') + return out.get("Output", "") except ClientError as err: - return 'error: cannot get logs for {}:\n{}'.format(instance_id, err) + return "error: cannot get logs for {}:\n{}".format(instance_id, err)