Black formatting

This commit is contained in:
clearml 2024-10-06 19:23:08 +03:00
parent 972696450e
commit 2b6ab4edc8

View File

@ -7,10 +7,10 @@ import attr
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.session.defs import ENV_AUTH_TOKEN from ..backend_api.session.defs import ENV_AUTH_TOKEN
env_git_user = 'CLEARML_AUTOSCALER_GIT_USER' env_git_user = "CLEARML_AUTOSCALER_GIT_USER"
env_git_pass = 'CLEARML_AUTOSCALER_GIT_PASSWORD' env_git_pass = "CLEARML_AUTOSCALER_GIT_PASSWORD"
bash_script_template = '''\ bash_script_template = """\
#!/bin/bash #!/bin/bash
set -x set -x
@ -49,13 +49,13 @@ then
fi fi
shutdown shutdown
''' """
clearml_conf_template = '''\ clearml_conf_template = """\
agent.git_user="{git_user}" agent.git_user="{git_user}"
agent.git_pass="{git_pass}" agent.git_pass="{git_pass}"
{extra_clearml_conf} {extra_clearml_conf}
''' """
@attr.s @attr.s
@ -76,7 +76,7 @@ class CloudDriver(ABC):
# Other # Other
extra_vm_bash_script = attr.ib() extra_vm_bash_script = attr.ib()
docker_image = attr.ib() docker_image = attr.ib()
tags = attr.ib(default='') tags = attr.ib(default="")
session = attr.ib(default=None) session = attr.ib(default=None)
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -125,15 +125,13 @@ class CloudDriver(ABC):
return bash_script_template.format( return bash_script_template.format(
queue=queue_name, queue=queue_name,
worker_prefix=worker_prefix, worker_prefix=worker_prefix,
auth_token=self.auth_token or "",
auth_token=self.auth_token or '', access_key=self.access_key or "",
access_key=self.access_key or '',
api_server=self.api_server, api_server=self.api_server,
clearml_conf=self.clearml_conf(), clearml_conf=self.clearml_conf(),
files_server=self.files_server, files_server=self.files_server,
secret_key=self.secret_key or '', secret_key=self.secret_key or "",
web_server=self.web_server, web_server=self.web_server,
bash_script=("export NVIDIA_VISIBLE_DEVICES=none; " if cpu_only else "") + self.extra_vm_bash_script, bash_script=("export NVIDIA_VISIBLE_DEVICES=none; " if cpu_only else "") + self.extra_vm_bash_script,
driver_extra=self.driver_bash_extra(task_id), driver_extra=self.driver_bash_extra(task_id),
docker="--docker '{}'".format(self.docker_image) if self.docker_image else "", docker="--docker '{}'".format(self.docker_image) if self.docker_image else "",
@ -142,8 +140,8 @@ class CloudDriver(ABC):
def clearml_conf(self): def clearml_conf(self):
# TODO: This need to be documented somewhere # TODO: This need to be documented somewhere
git_user = environ.get(env_git_user) or self.git_user or '' git_user = environ.get(env_git_user) or self.git_user or ""
git_pass = environ.get(env_git_pass) or self.git_pass or '' git_pass = environ.get(env_git_pass) or self.git_pass or ""
return clearml_conf_template.format( return clearml_conf_template.format(
git_user=git_user, git_user=git_user,
@ -153,27 +151,27 @@ class CloudDriver(ABC):
def driver_bash_extra(self, task_id): def driver_bash_extra(self, task_id):
if not task_id: if not task_id:
return '' return ""
return 'python -m clearml_agent --config-file ~/clearml.conf execute --id {}'.format(task_id) return "python -m clearml_agent --config-file ~/clearml.conf execute --id {}".format(task_id)
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
session = Session() session = Session()
hyper_params, configurations = config['hyper_params'], config['configurations'] hyper_params, configurations = config["hyper_params"], config["configurations"]
opts = { opts = {
'git_user': hyper_params['git_user'], "git_user": hyper_params["git_user"],
'git_pass': hyper_params['git_pass'], "git_pass": hyper_params["git_pass"],
'extra_clearml_conf': configurations['extra_clearml_conf'], "extra_clearml_conf": configurations["extra_clearml_conf"],
'api_server': session.get_api_server_host(), "api_server": session.get_api_server_host(),
'web_server': session.get_app_server_host(), "web_server": session.get_app_server_host(),
'files_server': session.get_files_server_host(), "files_server": session.get_files_server_host(),
'access_key': session.access_key, "access_key": session.access_key,
'secret_key': session.secret_key, "secret_key": session.secret_key,
'auth_token': ENV_AUTH_TOKEN.get(), "auth_token": ENV_AUTH_TOKEN.get(),
'extra_vm_bash_script': configurations['extra_vm_bash_script'], "extra_vm_bash_script": configurations["extra_vm_bash_script"],
'docker_image': hyper_params['default_docker_image'], "docker_image": hyper_params["default_docker_image"],
'tags': hyper_params.get('tags', ''), "tags": hyper_params.get("tags", ""),
'session': session, "session": session,
} }
return cls(**opts) return cls(**opts)
@ -184,7 +182,7 @@ class CloudDriver(ABC):
def logger(self): def logger(self):
if self.scaler: if self.scaler:
return self.scaler.logger return self.scaler.logger
return logging.getLogger('AWSDriver') return logging.getLogger("AWSDriver")
def parse_tags(s): def parse_tags(s):
@ -197,10 +195,10 @@ def parse_tags(s):
return [] return []
tags = [] tags = []
for kv in s.split(','): for kv in s.split(","):
if '=' not in kv: if "=" not in kv:
raise ValueError(kv) raise ValueError(kv)
key, value = [v.strip() for v in kv.split('=', 1)] key, value = [v.strip() for v in kv.split("=", 1)]
if not key or not value: if not key or not value:
raise ValueError(kv) raise ValueError(kv)
tags.append((key, value)) tags.append((key, value))