Black formatting

This commit is contained in:
allegroai 2024-08-12 20:16:07 +03:00
parent 35c8b9f7d3
commit bb8975065b

View File

@ -26,7 +26,8 @@ _workers_pattern = re.compile(
(?P<instance_type>[^:]+)
(:(?P<cloud_id>[^:/]+))?
$
""", re.VERBOSE
""",
re.VERBOSE,
)
MINUTE = 60.0
@ -42,14 +43,14 @@ class WorkerId:
self.prefix = match["prefix"]
self.name = match["name"]
self.instance_type = match["instance_type"]
self.cloud_id = match["cloud_id"] or ''
self.cloud_id = match["cloud_id"] or ""
class State(str, Enum):
STARTING = 'starting'
READY = 'ready'
RUNNING = 'running'
STOPPED = 'stopped'
STARTING = "starting"
READY = "ready"
RUNNING = "running"
STOPPED = "stopped"
@attr.s
@ -64,35 +65,31 @@ class ScalerConfig:
@classmethod
def from_config(cls, config):
return cls(
max_idle_time_min=config['hyper_params']['max_idle_time_min'],
polling_interval_time_min=config['hyper_params']['polling_interval_time_min'],
max_spin_up_time_min=config['hyper_params']['max_spin_up_time_min'],
workers_prefix=config['hyper_params']['workers_prefix'],
resource_configurations=config['configurations']['resource_configurations'],
queues=config['configurations']['queues'],
max_idle_time_min=config["hyper_params"]["max_idle_time_min"],
polling_interval_time_min=config["hyper_params"]["polling_interval_time_min"],
max_spin_up_time_min=config["hyper_params"]["max_spin_up_time_min"],
workers_prefix=config["hyper_params"]["workers_prefix"],
resource_configurations=config["configurations"]["resource_configurations"],
queues=config["configurations"]["queues"],
)
class AutoScaler(object):
def __init__(self, config, driver: CloudDriver, logger=None):
self.logger = logger or get_logger('auto_scaler')
self.logger = logger or get_logger("auto_scaler")
# Should be after we create logger
self.state = State.STARTING
self.driver = driver
self.logger.info('using %s driver', self.driver.kind())
self.logger.info("using %s driver", self.driver.kind())
self.driver.set_scaler(self)
self.resource_configurations = config.resource_configurations
self.queues = config.queues # queue name -> list of resources
self.resource_to_queue = {
item[0]: queue
for queue, resources in self.queues.items()
for item in resources
}
self.resource_to_queue = {item[0]: queue for queue, resources in self.queues.items() for item in resources}
if not self.sanity_check():
raise ValueError('health check failed')
raise ValueError("health check failed")
self.max_idle_time_min = float(config.max_idle_time_min)
self.polling_interval_time_min = float(config.polling_interval_time_min)
@ -144,17 +141,17 @@ class AutoScaler(object):
try:
self.supervisor()
except Exception as ex:
self.logger.exception('Error: %r, retrying in 15 seconds', ex)
self.logger.exception("Error: %r, retrying in 15 seconds", ex)
sleep(15)
def stop(self):
self.logger.info('stopping')
self.logger.info("stopping")
self._stop_event.set()
self.state = State.STOPPED
def ensure_queues(self):
# Verify the requested queues exist and create those that doesn't exist
all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=['name']))}
all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=["name"]))}
missing_queues = set(self.queues) - all_queues
for q in missing_queues:
self.logger.info("Creating queue %r", q)
@ -163,7 +160,7 @@ class AutoScaler(object):
def queue_mapping(self):
id_to_name = {}
name_to_id = {}
for queue in self.api_client.queues.get_all(only_fields=['id', 'name']):
for queue in self.api_client.queues.get_all(only_fields=["id", "name"]):
id_to_name[queue.id] = queue.name
name_to_id[queue.name] = queue.id
@ -177,14 +174,14 @@ class AutoScaler(object):
if wid.prefix == self.workers_prefix:
workers.append(worker)
except ValueError:
self.logger.info('ignoring unknown worker: %r', worker.id)
self.logger.info("ignoring unknown worker: %r", worker.id)
return workers
def stale_workers(self, spun_workers):
now = time()
for worker_id, (resource, spin_time) in list(spun_workers.items()):
if now - spin_time > self.max_spin_up_time_min * MINUTE:
self.logger.info('Stuck spun instance %s of type %s', worker_id, resource)
self.logger.info("Stuck spun instance %s of type %s", worker_id, resource)
yield worker_id
def extra_allocations(self):
@ -192,7 +189,7 @@ class AutoScaler(object):
return []
def gen_worker_prefix(self, resource, resource_conf):
return '{workers_prefix}:{worker_type}:{instance_type}'.format(
return "{workers_prefix}:{worker_type}:{instance_type}".format(
workers_prefix=self.workers_prefix,
worker_type=resource,
instance_type=resource_conf["instance_type"],
@ -202,7 +199,7 @@ class AutoScaler(object):
self.logger.info("Checking if worker %r is still idle", worker_id)
for worker in self.api_client.workers.get_all():
if worker.id == worker_id:
return getattr(worker, 'task', None) is None
return getattr(worker, "task", None) is None
return True
def supervisor(self):
@ -237,7 +234,7 @@ class AutoScaler(object):
if worker.id not in previous_workers:
if not spun_workers.pop(worker.id, None):
if worker.id not in unknown_workers:
self.logger.info('Removed unknown worker from spun_workers: %s', worker.id)
self.logger.info("Removed unknown worker from spun_workers: %s", worker.id)
unknown_workers.append(worker.id)
else:
previous_workers.add(worker.id)
@ -245,15 +242,15 @@ class AutoScaler(object):
for worker_id in self.stale_workers(spun_workers):
out = spun_workers.pop(worker_id, None)
if out is None:
self.logger.warning('Ignoring unknown stale worker: %r', worker_id)
self.logger.warning("Ignoring unknown stale worker: %r", worker_id)
continue
resource = out[0]
try:
self.logger.info('Spinning down stuck worker: %r', worker_id)
self.logger.info("Spinning down stuck worker: %r", worker_id)
self.driver.spin_down_worker(WorkerId(worker_id).cloud_id)
up_machines[resource] -= 1
except Exception as err:
self.logger.info('Cannot spin down %r: %r', worker_id, err)
self.logger.info("Cannot spin down %r: %r", worker_id, err)
self.update_idle_workers(all_workers, idle_workers)
required_idle_resources = [] # idle resources we'll need to keep running
@ -289,13 +286,12 @@ class AutoScaler(object):
break
# check if we can add instances to `resource`
currently_running_workers = len(
[worker for worker in all_workers if WorkerId(worker.id).name == resource])
[worker for worker in all_workers if WorkerId(worker.id).name == resource]
)
spun_up_workers = sum(1 for r, _ in spun_workers.values() if r == resource)
max_allowed = int(max_instances) - currently_running_workers - spun_up_workers
if max_allowed > 0:
spin_up_resources.extend(
[resource] * min(spin_up_count, max_allowed)
)
spin_up_resources.extend([resource] * min(spin_up_count, max_allowed))
allocate_new_resources.extend(spin_up_resources)
# Now we actually spin the new machines
@ -307,16 +303,20 @@ class AutoScaler(object):
resource = WorkerId(worker_id).name
queue = self.resource_to_queue[resource]
suffix = ', task_id={!r}'.format(task_id) if task_id else ''
suffix = ", task_id={!r}".format(task_id) if task_id else ""
self.logger.info(
'Spinning new instance resource=%r, prefix=%r, queue=%r%s',
resource, self.workers_prefix, queue, suffix)
"Spinning new instance resource=%r, prefix=%r, queue=%r%s",
resource,
self.workers_prefix,
queue,
suffix,
)
resource_conf = self.resource_configurations[resource]
worker_prefix = self.gen_worker_prefix(resource, resource_conf)
instance_id = self.driver.spin_up_worker(resource_conf, worker_prefix, queue, task_id=task_id)
self.monitor_startup(instance_id)
worker_id = '{}:{}'.format(worker_prefix, instance_id)
self.logger.info('New instance ID: %s', instance_id)
worker_id = "{}:{}".format(worker_prefix, instance_id)
self.logger.info("New instance ID: %s", instance_id)
spun_workers[worker_id] = (resource, time())
up_machines[resource] += 1
except Exception as ex:
@ -353,7 +353,7 @@ class AutoScaler(object):
return
for worker in all_workers:
task = getattr(worker, 'task', None)
task = getattr(worker, "task", None)
if not task:
if worker.id not in idle_workers:
resource_name = WorkerId(worker.id).name
@ -366,9 +366,9 @@ class AutoScaler(object):
return not self._stop_event.is_set()
def report_app_stats(self, logger, queue_id_to_name, up_machines, idle_workers):
self.logger.info('resources: %r', self.resource_to_queue)
self.logger.info('idle worker: %r', idle_workers)
self.logger.info('up machines: %r', up_machines)
self.logger.info("resources: %r", self.resource_to_queue)
self.logger.info("idle worker: %r", idle_workers)
self.logger.info("up machines: %r", up_machines)
# Using property for state to log state change
@property
@ -377,11 +377,11 @@ class AutoScaler(object):
@state.setter
def state(self, value):
prev = getattr(self, '_state', None)
prev = getattr(self, "_state", None)
if prev:
self.logger.info('state change: %s -> %s', prev, value)
self.logger.info("state change: %s -> %s", prev, value)
else:
self.logger.info('initial state: %s', value)
self.logger.info("initial state: %s", value)
self._state = value
def monitor_startup(self, instance_id):
@ -396,15 +396,15 @@ class AutoScaler(object):
# TODO: Find a cross cloud way to get incremental logs
last_lnum = 0
while time() - start <= self.max_spin_up_time_min * MINUTE:
self.logger.info('getting startup logs for %r', instance_id)
self.logger.info("getting startup logs for %r", instance_id)
data = self.driver.console_log(instance_id)
lines = data.splitlines()
if not lines:
self.logger.info('not startup logs for %r', instance_id)
self.logger.info("not startup logs for %r", instance_id)
else:
last_lnum, lines = latest_lines(lines, last_lnum)
for line in lines:
self.logger.info('%r STARTUP LOG: %s', instance_id, line)
self.logger.info("%r STARTUP LOG: %s", instance_id, line)
sleep(MINUTE)
@ -437,9 +437,9 @@ def has_duplicate_resource(queues: dict):
def worker_last_time(worker):
"""Last time we heard from a worker. Current time if we can't find"""
time_attrs = [
'register_time',
'last_activity_time',
'last_report_time',
"register_time",
"last_activity_time",
"last_report_time",
]
times = [getattr(worker, attr).timestamp() for attr in time_attrs if getattr(worker, attr)]
return max(times) if times else time()