diff --git a/clearml/automation/auto_scaler.py b/clearml/automation/auto_scaler.py index 2218fb2a..5e7c4cb0 100644 --- a/clearml/automation/auto_scaler.py +++ b/clearml/automation/auto_scaler.py @@ -26,7 +26,8 @@ _workers_pattern = re.compile( (?P[^:]+) (:(?P[^:/]+))? $ - """, re.VERBOSE + """, + re.VERBOSE, ) MINUTE = 60.0 @@ -42,14 +43,14 @@ class WorkerId: self.prefix = match["prefix"] self.name = match["name"] self.instance_type = match["instance_type"] - self.cloud_id = match["cloud_id"] or '' + self.cloud_id = match["cloud_id"] or "" class State(str, Enum): - STARTING = 'starting' - READY = 'ready' - RUNNING = 'running' - STOPPED = 'stopped' + STARTING = "starting" + READY = "ready" + RUNNING = "running" + STOPPED = "stopped" @attr.s @@ -64,35 +65,31 @@ class ScalerConfig: @classmethod def from_config(cls, config): return cls( - max_idle_time_min=config['hyper_params']['max_idle_time_min'], - polling_interval_time_min=config['hyper_params']['polling_interval_time_min'], - max_spin_up_time_min=config['hyper_params']['max_spin_up_time_min'], - workers_prefix=config['hyper_params']['workers_prefix'], - resource_configurations=config['configurations']['resource_configurations'], - queues=config['configurations']['queues'], + max_idle_time_min=config["hyper_params"]["max_idle_time_min"], + polling_interval_time_min=config["hyper_params"]["polling_interval_time_min"], + max_spin_up_time_min=config["hyper_params"]["max_spin_up_time_min"], + workers_prefix=config["hyper_params"]["workers_prefix"], + resource_configurations=config["configurations"]["resource_configurations"], + queues=config["configurations"]["queues"], ) class AutoScaler(object): def __init__(self, config, driver: CloudDriver, logger=None): - self.logger = logger or get_logger('auto_scaler') + self.logger = logger or get_logger("auto_scaler") # Should be after we create logger self.state = State.STARTING self.driver = driver - self.logger.info('using %s driver', self.driver.kind()) + self.logger.info("using %s driver", self.driver.kind()) self.driver.set_scaler(self) self.resource_configurations = config.resource_configurations self.queues = config.queues # queue name -> list of resources - self.resource_to_queue = { - item[0]: queue - for queue, resources in self.queues.items() - for item in resources - } + self.resource_to_queue = {item[0]: queue for queue, resources in self.queues.items() for item in resources} if not self.sanity_check(): - raise ValueError('health check failed') + raise ValueError("health check failed") self.max_idle_time_min = float(config.max_idle_time_min) self.polling_interval_time_min = float(config.polling_interval_time_min) @@ -144,17 +141,17 @@ class AutoScaler(object): try: self.supervisor() except Exception as ex: - self.logger.exception('Error: %r, retrying in 15 seconds', ex) + self.logger.exception("Error: %r, retrying in 15 seconds", ex) sleep(15) def stop(self): - self.logger.info('stopping') + self.logger.info("stopping") self._stop_event.set() self.state = State.STOPPED def ensure_queues(self): # Verify the requested queues exist and create those that doesn't exist - all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=['name']))} + all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=["name"]))} missing_queues = set(self.queues) - all_queues for q in missing_queues: self.logger.info("Creating queue %r", q) @@ -163,7 +160,7 @@ class AutoScaler(object): def queue_mapping(self): id_to_name = {} name_to_id = {} - for queue in self.api_client.queues.get_all(only_fields=['id', 'name']): + for queue in self.api_client.queues.get_all(only_fields=["id", "name"]): id_to_name[queue.id] = queue.name name_to_id[queue.name] = queue.id @@ -177,14 +174,14 @@ class AutoScaler(object): if wid.prefix == self.workers_prefix: workers.append(worker) except ValueError: - self.logger.info('ignoring unknown worker: %r', worker.id) + self.logger.info("ignoring unknown worker: %r", worker.id) return workers def stale_workers(self, spun_workers): now = time() for worker_id, (resource, spin_time) in list(spun_workers.items()): if now - spin_time > self.max_spin_up_time_min * MINUTE: - self.logger.info('Stuck spun instance %s of type %s', worker_id, resource) + self.logger.info("Stuck spun instance %s of type %s", worker_id, resource) yield worker_id def extra_allocations(self): @@ -192,7 +189,7 @@ class AutoScaler(object): return [] def gen_worker_prefix(self, resource, resource_conf): - return '{workers_prefix}:{worker_type}:{instance_type}'.format( + return "{workers_prefix}:{worker_type}:{instance_type}".format( workers_prefix=self.workers_prefix, worker_type=resource, instance_type=resource_conf["instance_type"], @@ -202,7 +199,7 @@ class AutoScaler(object): self.logger.info("Checking if worker %r is still idle", worker_id) for worker in self.api_client.workers.get_all(): if worker.id == worker_id: - return getattr(worker, 'task', None) is None + return getattr(worker, "task", None) is None return True def supervisor(self): @@ -237,7 +234,7 @@ class AutoScaler(object): if worker.id not in previous_workers: if not spun_workers.pop(worker.id, None): if worker.id not in unknown_workers: - self.logger.info('Removed unknown worker from spun_workers: %s', worker.id) + self.logger.info("Removed unknown worker from spun_workers: %s", worker.id) unknown_workers.append(worker.id) else: previous_workers.add(worker.id) @@ -245,15 +242,15 @@ class AutoScaler(object): for worker_id in self.stale_workers(spun_workers): out = spun_workers.pop(worker_id, None) if out is None: - self.logger.warning('Ignoring unknown stale worker: %r', worker_id) + self.logger.warning("Ignoring unknown stale worker: %r", worker_id) continue resource = out[0] try: - self.logger.info('Spinning down stuck worker: %r', worker_id) + self.logger.info("Spinning down stuck worker: %r", worker_id) self.driver.spin_down_worker(WorkerId(worker_id).cloud_id) up_machines[resource] -= 1 except Exception as err: - self.logger.info('Cannot spin down %r: %r', worker_id, err) + self.logger.info("Cannot spin down %r: %r", worker_id, err) self.update_idle_workers(all_workers, idle_workers) required_idle_resources = [] # idle resources we'll need to keep running @@ -289,13 +286,12 @@ class AutoScaler(object): break # check if we can add instances to `resource` currently_running_workers = len( - [worker for worker in all_workers if WorkerId(worker.id).name == resource]) + [worker for worker in all_workers if WorkerId(worker.id).name == resource] + ) spun_up_workers = sum(1 for r, _ in spun_workers.values() if r == resource) max_allowed = int(max_instances) - currently_running_workers - spun_up_workers if max_allowed > 0: - spin_up_resources.extend( - [resource] * min(spin_up_count, max_allowed) - ) + spin_up_resources.extend([resource] * min(spin_up_count, max_allowed)) allocate_new_resources.extend(spin_up_resources) # Now we actually spin the new machines @@ -307,16 +303,20 @@ class AutoScaler(object): resource = WorkerId(worker_id).name queue = self.resource_to_queue[resource] - suffix = ', task_id={!r}'.format(task_id) if task_id else '' + suffix = ", task_id={!r}".format(task_id) if task_id else "" self.logger.info( - 'Spinning new instance resource=%r, prefix=%r, queue=%r%s', - resource, self.workers_prefix, queue, suffix) + "Spinning new instance resource=%r, prefix=%r, queue=%r%s", + resource, + self.workers_prefix, + queue, + suffix, + ) resource_conf = self.resource_configurations[resource] worker_prefix = self.gen_worker_prefix(resource, resource_conf) instance_id = self.driver.spin_up_worker(resource_conf, worker_prefix, queue, task_id=task_id) self.monitor_startup(instance_id) - worker_id = '{}:{}'.format(worker_prefix, instance_id) - self.logger.info('New instance ID: %s', instance_id) + worker_id = "{}:{}".format(worker_prefix, instance_id) + self.logger.info("New instance ID: %s", instance_id) spun_workers[worker_id] = (resource, time()) up_machines[resource] += 1 except Exception as ex: @@ -353,7 +353,7 @@ class AutoScaler(object): return for worker in all_workers: - task = getattr(worker, 'task', None) + task = getattr(worker, "task", None) if not task: if worker.id not in idle_workers: resource_name = WorkerId(worker.id).name @@ -366,9 +366,9 @@ class AutoScaler(object): return not self._stop_event.is_set() def report_app_stats(self, logger, queue_id_to_name, up_machines, idle_workers): - self.logger.info('resources: %r', self.resource_to_queue) - self.logger.info('idle worker: %r', idle_workers) - self.logger.info('up machines: %r', up_machines) + self.logger.info("resources: %r", self.resource_to_queue) + self.logger.info("idle worker: %r", idle_workers) + self.logger.info("up machines: %r", up_machines) # Using property for state to log state change @property @@ -377,11 +377,11 @@ class AutoScaler(object): @state.setter def state(self, value): - prev = getattr(self, '_state', None) + prev = getattr(self, "_state", None) if prev: - self.logger.info('state change: %s -> %s', prev, value) + self.logger.info("state change: %s -> %s", prev, value) else: - self.logger.info('initial state: %s', value) + self.logger.info("initial state: %s", value) self._state = value def monitor_startup(self, instance_id): @@ -396,15 +396,15 @@ class AutoScaler(object): # TODO: Find a cross cloud way to get incremental logs last_lnum = 0 while time() - start <= self.max_spin_up_time_min * MINUTE: - self.logger.info('getting startup logs for %r', instance_id) + self.logger.info("getting startup logs for %r", instance_id) data = self.driver.console_log(instance_id) lines = data.splitlines() if not lines: - self.logger.info('not startup logs for %r', instance_id) + self.logger.info("not startup logs for %r", instance_id) else: last_lnum, lines = latest_lines(lines, last_lnum) for line in lines: - self.logger.info('%r STARTUP LOG: %s', instance_id, line) + self.logger.info("%r STARTUP LOG: %s", instance_id, line) sleep(MINUTE) @@ -437,9 +437,9 @@ def has_duplicate_resource(queues: dict): def worker_last_time(worker): """Last time we heard from a worker. Current time if we can't find""" time_attrs = [ - 'register_time', - 'last_activity_time', - 'last_report_time', + "register_time", + "last_activity_time", + "last_report_time", ] times = [getattr(worker, attr).timestamp() for attr in time_attrs if getattr(worker, attr)] return max(times) if times else time()