diff --git a/clearml/automation/cloud_driver.py b/clearml/automation/cloud_driver.py index 0e069864..489518da 100644 --- a/clearml/automation/cloud_driver.py +++ b/clearml/automation/cloud_driver.py @@ -7,10 +7,10 @@ import attr from ..backend_api import Session from ..backend_api.session.defs import ENV_AUTH_TOKEN -env_git_user = 'CLEARML_AUTOSCALER_GIT_USER' -env_git_pass = 'CLEARML_AUTOSCALER_GIT_PASSWORD' +env_git_user = "CLEARML_AUTOSCALER_GIT_USER" +env_git_pass = "CLEARML_AUTOSCALER_GIT_PASSWORD" -bash_script_template = '''\ +bash_script_template = """\ #!/bin/bash set -x @@ -49,13 +49,13 @@ then fi shutdown -''' +""" -clearml_conf_template = '''\ +clearml_conf_template = """\ agent.git_user="{git_user}" agent.git_pass="{git_pass}" {extra_clearml_conf} -''' +""" @attr.s @@ -76,7 +76,7 @@ class CloudDriver(ABC): # Other extra_vm_bash_script = attr.ib() docker_image = attr.ib() - tags = attr.ib(default='') + tags = attr.ib(default="") session = attr.ib(default=None) def __attrs_post_init__(self): @@ -125,15 +125,13 @@ class CloudDriver(ABC): return bash_script_template.format( queue=queue_name, worker_prefix=worker_prefix, - - auth_token=self.auth_token or '', - access_key=self.access_key or '', + auth_token=self.auth_token or "", + access_key=self.access_key or "", api_server=self.api_server, clearml_conf=self.clearml_conf(), files_server=self.files_server, - secret_key=self.secret_key or '', + secret_key=self.secret_key or "", web_server=self.web_server, - bash_script=("export NVIDIA_VISIBLE_DEVICES=none; " if cpu_only else "") + self.extra_vm_bash_script, driver_extra=self.driver_bash_extra(task_id), docker="--docker '{}'".format(self.docker_image) if self.docker_image else "", @@ -142,8 +140,8 @@ class CloudDriver(ABC): def clearml_conf(self): # TODO: This need to be documented somewhere - git_user = environ.get(env_git_user) or self.git_user or '' - git_pass = environ.get(env_git_pass) or self.git_pass or '' + git_user = environ.get(env_git_user) or self.git_user or "" + git_pass = environ.get(env_git_pass) or self.git_pass or "" return clearml_conf_template.format( git_user=git_user, @@ -153,27 +151,27 @@ class CloudDriver(ABC): def driver_bash_extra(self, task_id): if not task_id: - return '' - return 'python -m clearml_agent --config-file ~/clearml.conf execute --id {}'.format(task_id) + return "" + return "python -m clearml_agent --config-file ~/clearml.conf execute --id {}".format(task_id) @classmethod def from_config(cls, config): session = Session() - hyper_params, configurations = config['hyper_params'], config['configurations'] + hyper_params, configurations = config["hyper_params"], config["configurations"] opts = { - 'git_user': hyper_params['git_user'], - 'git_pass': hyper_params['git_pass'], - 'extra_clearml_conf': configurations['extra_clearml_conf'], - 'api_server': session.get_api_server_host(), - 'web_server': session.get_app_server_host(), - 'files_server': session.get_files_server_host(), - 'access_key': session.access_key, - 'secret_key': session.secret_key, - 'auth_token': ENV_AUTH_TOKEN.get(), - 'extra_vm_bash_script': configurations['extra_vm_bash_script'], - 'docker_image': hyper_params['default_docker_image'], - 'tags': hyper_params.get('tags', ''), - 'session': session, + "git_user": hyper_params["git_user"], + "git_pass": hyper_params["git_pass"], + "extra_clearml_conf": configurations["extra_clearml_conf"], + "api_server": session.get_api_server_host(), + "web_server": session.get_app_server_host(), + "files_server": session.get_files_server_host(), + "access_key": session.access_key, + "secret_key": session.secret_key, + "auth_token": ENV_AUTH_TOKEN.get(), + "extra_vm_bash_script": configurations["extra_vm_bash_script"], + "docker_image": hyper_params["default_docker_image"], + "tags": hyper_params.get("tags", ""), + "session": session, } return cls(**opts) @@ -184,7 +182,7 @@ class CloudDriver(ABC): def logger(self): if self.scaler: return self.scaler.logger - return logging.getLogger('AWSDriver') + return logging.getLogger("AWSDriver") def parse_tags(s): @@ -197,10 +195,10 @@ def parse_tags(s): return [] tags = [] - for kv in s.split(','): - if '=' not in kv: + for kv in s.split(","): + if "=" not in kv: raise ValueError(kv) - key, value = [v.strip() for v in kv.split('=', 1)] + key, value = [v.strip() for v in kv.split("=", 1)] if not key or not value: raise ValueError(kv) tags.append((key, value))