This commit is contained in:
revital 2024-08-14 15:10:26 +03:00
commit 9838631336
4 changed files with 88 additions and 65 deletions

View File

@ -26,7 +26,8 @@ _workers_pattern = re.compile(
(?P<instance_type>[^:]+)
(:(?P<cloud_id>[^:/]+))?
$
""", re.VERBOSE
""",
re.VERBOSE,
)
MINUTE = 60.0
@ -42,14 +43,14 @@ class WorkerId:
self.prefix = match["prefix"]
self.name = match["name"]
self.instance_type = match["instance_type"]
self.cloud_id = match["cloud_id"] or ''
self.cloud_id = match["cloud_id"] or ""
class State(str, Enum):
STARTING = 'starting'
READY = 'ready'
RUNNING = 'running'
STOPPED = 'stopped'
STARTING = "starting"
READY = "ready"
RUNNING = "running"
STOPPED = "stopped"
@attr.s
@ -64,35 +65,31 @@ class ScalerConfig:
@classmethod
def from_config(cls, config):
return cls(
max_idle_time_min=config['hyper_params']['max_idle_time_min'],
polling_interval_time_min=config['hyper_params']['polling_interval_time_min'],
max_spin_up_time_min=config['hyper_params']['max_spin_up_time_min'],
workers_prefix=config['hyper_params']['workers_prefix'],
resource_configurations=config['configurations']['resource_configurations'],
queues=config['configurations']['queues'],
max_idle_time_min=config["hyper_params"]["max_idle_time_min"],
polling_interval_time_min=config["hyper_params"]["polling_interval_time_min"],
max_spin_up_time_min=config["hyper_params"]["max_spin_up_time_min"],
workers_prefix=config["hyper_params"]["workers_prefix"],
resource_configurations=config["configurations"]["resource_configurations"],
queues=config["configurations"]["queues"],
)
class AutoScaler(object):
def __init__(self, config, driver: CloudDriver, logger=None):
self.logger = logger or get_logger('auto_scaler')
self.logger = logger or get_logger("auto_scaler")
# Should be after we create logger
self.state = State.STARTING
self.driver = driver
self.logger.info('using %s driver', self.driver.kind())
self.logger.info("using %s driver", self.driver.kind())
self.driver.set_scaler(self)
self.resource_configurations = config.resource_configurations
self.queues = config.queues # queue name -> list of resources
self.resource_to_queue = {
item[0]: queue
for queue, resources in self.queues.items()
for item in resources
}
self.resource_to_queue = {item[0]: queue for queue, resources in self.queues.items() for item in resources}
if not self.sanity_check():
raise ValueError('health check failed')
raise ValueError("health check failed")
self.max_idle_time_min = float(config.max_idle_time_min)
self.polling_interval_time_min = float(config.polling_interval_time_min)
@ -144,17 +141,17 @@ class AutoScaler(object):
try:
self.supervisor()
except Exception as ex:
self.logger.exception('Error: %r, retrying in 15 seconds', ex)
self.logger.exception("Error: %r, retrying in 15 seconds", ex)
sleep(15)
def stop(self):
self.logger.info('stopping')
self.logger.info("stopping")
self._stop_event.set()
self.state = State.STOPPED
def ensure_queues(self):
# Verify the requested queues exist and create those that doesn't exist
all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=['name']))}
all_queues = {q.name for q in list(self.api_client.queues.get_all(only_fields=["name"]))}
missing_queues = set(self.queues) - all_queues
for q in missing_queues:
self.logger.info("Creating queue %r", q)
@ -163,7 +160,7 @@ class AutoScaler(object):
def queue_mapping(self):
id_to_name = {}
name_to_id = {}
for queue in self.api_client.queues.get_all(only_fields=['id', 'name']):
for queue in self.api_client.queues.get_all(only_fields=["id", "name"]):
id_to_name[queue.id] = queue.name
name_to_id[queue.name] = queue.id
@ -177,14 +174,14 @@ class AutoScaler(object):
if wid.prefix == self.workers_prefix:
workers.append(worker)
except ValueError:
self.logger.info('ignoring unknown worker: %r', worker.id)
self.logger.info("ignoring unknown worker: %r", worker.id)
return workers
def stale_workers(self, spun_workers):
now = time()
for worker_id, (resource, spin_time) in list(spun_workers.items()):
if now - spin_time > self.max_spin_up_time_min * MINUTE:
self.logger.info('Stuck spun instance %s of type %s', worker_id, resource)
self.logger.info("Stuck spun instance %s of type %s", worker_id, resource)
yield worker_id
def extra_allocations(self):
@ -192,7 +189,7 @@ class AutoScaler(object):
return []
def gen_worker_prefix(self, resource, resource_conf):
return '{workers_prefix}:{worker_type}:{instance_type}'.format(
return "{workers_prefix}:{worker_type}:{instance_type}".format(
workers_prefix=self.workers_prefix,
worker_type=resource,
instance_type=resource_conf["instance_type"],
@ -202,7 +199,7 @@ class AutoScaler(object):
self.logger.info("Checking if worker %r is still idle", worker_id)
for worker in self.api_client.workers.get_all():
if worker.id == worker_id:
return getattr(worker, 'task', None) is None
return getattr(worker, "task", None) is None
return True
def supervisor(self):
@ -237,7 +234,7 @@ class AutoScaler(object):
if worker.id not in previous_workers:
if not spun_workers.pop(worker.id, None):
if worker.id not in unknown_workers:
self.logger.info('Removed unknown worker from spun_workers: %s', worker.id)
self.logger.info("Removed unknown worker from spun_workers: %s", worker.id)
unknown_workers.append(worker.id)
else:
previous_workers.add(worker.id)
@ -245,15 +242,15 @@ class AutoScaler(object):
for worker_id in self.stale_workers(spun_workers):
out = spun_workers.pop(worker_id, None)
if out is None:
self.logger.warning('Ignoring unknown stale worker: %r', worker_id)
self.logger.warning("Ignoring unknown stale worker: %r", worker_id)
continue
resource = out[0]
try:
self.logger.info('Spinning down stuck worker: %r', worker_id)
self.logger.info("Spinning down stuck worker: %r", worker_id)
self.driver.spin_down_worker(WorkerId(worker_id).cloud_id)
up_machines[resource] -= 1
except Exception as err:
self.logger.info('Cannot spin down %r: %r', worker_id, err)
self.logger.info("Cannot spin down %r: %r", worker_id, err)
self.update_idle_workers(all_workers, idle_workers)
required_idle_resources = [] # idle resources we'll need to keep running
@ -289,13 +286,12 @@ class AutoScaler(object):
break
# check if we can add instances to `resource`
currently_running_workers = len(
[worker for worker in all_workers if WorkerId(worker.id).name == resource])
[worker for worker in all_workers if WorkerId(worker.id).name == resource]
)
spun_up_workers = sum(1 for r, _ in spun_workers.values() if r == resource)
max_allowed = int(max_instances) - currently_running_workers - spun_up_workers
if max_allowed > 0:
spin_up_resources.extend(
[resource] * min(spin_up_count, max_allowed)
)
spin_up_resources.extend([resource] * min(spin_up_count, max_allowed))
allocate_new_resources.extend(spin_up_resources)
# Now we actually spin the new machines
@ -307,16 +303,20 @@ class AutoScaler(object):
resource = WorkerId(worker_id).name
queue = self.resource_to_queue[resource]
suffix = ', task_id={!r}'.format(task_id) if task_id else ''
suffix = ", task_id={!r}".format(task_id) if task_id else ""
self.logger.info(
'Spinning new instance resource=%r, prefix=%r, queue=%r%s',
resource, self.workers_prefix, queue, suffix)
"Spinning new instance resource=%r, prefix=%r, queue=%r%s",
resource,
self.workers_prefix,
queue,
suffix,
)
resource_conf = self.resource_configurations[resource]
worker_prefix = self.gen_worker_prefix(resource, resource_conf)
instance_id = self.driver.spin_up_worker(resource_conf, worker_prefix, queue, task_id=task_id)
self.monitor_startup(instance_id)
worker_id = '{}:{}'.format(worker_prefix, instance_id)
self.logger.info('New instance ID: %s', instance_id)
worker_id = "{}:{}".format(worker_prefix, instance_id)
self.logger.info("New instance ID: %s", instance_id)
spun_workers[worker_id] = (resource, time())
up_machines[resource] += 1
except Exception as ex:
@ -353,7 +353,7 @@ class AutoScaler(object):
return
for worker in all_workers:
task = getattr(worker, 'task', None)
task = getattr(worker, "task", None)
if not task:
if worker.id not in idle_workers:
resource_name = WorkerId(worker.id).name
@ -366,9 +366,9 @@ class AutoScaler(object):
return not self._stop_event.is_set()
def report_app_stats(self, logger, queue_id_to_name, up_machines, idle_workers):
self.logger.info('resources: %r', self.resource_to_queue)
self.logger.info('idle worker: %r', idle_workers)
self.logger.info('up machines: %r', up_machines)
self.logger.info("resources: %r", self.resource_to_queue)
self.logger.info("idle worker: %r", idle_workers)
self.logger.info("up machines: %r", up_machines)
# Using property for state to log state change
@property
@ -377,11 +377,11 @@ class AutoScaler(object):
@state.setter
def state(self, value):
prev = getattr(self, '_state', None)
prev = getattr(self, "_state", None)
if prev:
self.logger.info('state change: %s -> %s', prev, value)
self.logger.info("state change: %s -> %s", prev, value)
else:
self.logger.info('initial state: %s', value)
self.logger.info("initial state: %s", value)
self._state = value
def monitor_startup(self, instance_id):
@ -396,15 +396,15 @@ class AutoScaler(object):
# TODO: Find a cross cloud way to get incremental logs
last_lnum = 0
while time() - start <= self.max_spin_up_time_min * MINUTE:
self.logger.info('getting startup logs for %r', instance_id)
self.logger.info("getting startup logs for %r", instance_id)
data = self.driver.console_log(instance_id)
lines = data.splitlines()
if not lines:
self.logger.info('not startup logs for %r', instance_id)
self.logger.info("not startup logs for %r", instance_id)
else:
last_lnum, lines = latest_lines(lines, last_lnum)
for line in lines:
self.logger.info('%r STARTUP LOG: %s', instance_id, line)
self.logger.info("%r STARTUP LOG: %s", instance_id, line)
sleep(MINUTE)
@ -437,9 +437,9 @@ def has_duplicate_resource(queues: dict):
def worker_last_time(worker):
"""Last time we heard from a worker. Current time if we can't find"""
time_attrs = [
'register_time',
'last_activity_time',
'last_report_time',
"register_time",
"last_activity_time",
"last_report_time",
]
times = [getattr(worker, attr).timestamp() for attr in time_attrs if getattr(worker, attr)]
return max(times) if times else time()

View File

@ -48,6 +48,7 @@ class CreateAndPopulate(object):
force_single_script_file=False, # type: bool
raise_on_missing_entries=False, # type: bool
verbose=False, # type: bool
binary=None # type: Optional[str]
):
# type: (...) -> None
"""
@ -90,6 +91,7 @@ class CreateAndPopulate(object):
:param force_single_script_file: If True, do not auto-detect local repository
:param raise_on_missing_entries: If True, raise ValueError on missing entries when populating
:param verbose: If True, print verbose logging
:param binary: Binary used to launch the entry point
"""
if repo and len(urlparse(repo).scheme) <= 1 and not re.compile(self._VCS_SSH_REGEX).match(repo):
folder = repo
@ -136,6 +138,7 @@ class CreateAndPopulate(object):
self.force_single_script_file = bool(force_single_script_file)
self.raise_on_missing_entries = raise_on_missing_entries
self.verbose = verbose
self.binary = binary
def create_task(self, dry_run=False):
# type: (bool) -> Union[Task, Dict]
@ -148,6 +151,7 @@ class CreateAndPopulate(object):
local_entry_file = None
repo_info = None
stand_alone_script_outside_repo = False
entry_point = ""
# populate from local repository / script
if self.folder or (self.script and Path(self.script).is_file() and not self.repo):
self.folder = os.path.expandvars(os.path.expanduser(self.folder)) if self.folder else None
@ -222,7 +226,8 @@ class CreateAndPopulate(object):
# check if we have no repository and no requirements raise error
if self.raise_on_missing_entries and (not self.requirements_file and not self.packages) \
and not self.repo and (
not repo_info or not repo_info.script or not repo_info.script.get('repository')):
not repo_info or not repo_info.script or not repo_info.script.get('repository')) \
and (not entry_point or not entry_point.endswith(".sh")):
raise ValueError("Standalone script detected \'{}\', but no requirements provided".format(self.script))
if dry_run:
task = None
@ -266,10 +271,10 @@ class CreateAndPopulate(object):
task_state['script']['diff'] = repo_info.script['diff'] or ''
task_state['script']['working_dir'] = repo_info.script['working_dir']
task_state['script']['entry_point'] = repo_info.script['entry_point']
task_state['script']['binary'] = '/bin/bash' if (
task_state['script']['binary'] = self.binary or ('/bin/bash' if (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \
else repo_info.script['binary']
else repo_info.script['binary'])
task_state['script']['requirements'] = repo_info.script.get('requirements') or {}
if self.cwd:
cwd = self.cwd
@ -344,14 +349,20 @@ class CreateAndPopulate(object):
detailed_req_report=False,
force_single_script=True,
)
task_state['script']['binary'] = '/bin/bash' if (
task_state['script']['binary'] = self.binary or ('/bin/bash' if (
(repo_info.script['entry_point'] or '').lower().strip().endswith('.sh') and
not (repo_info.script['entry_point'] or '').lower().strip().startswith('-m ')) \
else repo_info.script['binary']
else repo_info.script['binary'])
task_state['script']['diff'] = repo_info.script['diff'] or ''
task_state['script']['entry_point'] = repo_info.script['entry_point']
if create_requirements:
task_state['script']['requirements'] = repo_info.script.get('requirements') or {}
else:
if self.binary:
task_state["script"]["binary"] = self.binary
elif entry_point and entry_point.lower().strip().endswith(".sh") and not \
entry_point.lower().strip().startswith("-m"):
task_state["script"]["binary"] = "/bin/bash"
else:
# standalone task
task_state['script']['entry_point'] = self.script if self.script else \

View File

@ -59,11 +59,20 @@ def setup_parser(parser):
type=str,
default=None,
help="Specify the entry point script for the remote execution. "
"Currently support .py .ipynb and .sh scripts (python, jupyter notebook, bash) "
"Currently supports .py .ipynb and .sh scripts (python, jupyter notebook, bash) "
"When used in tandem with --repo the script should be a relative path inside "
"the repository, for example: --script source/train.py "
"the repository, for example: --script source/train.py."
"When used with --folder it supports a direct path to a file inside the local "
"repository itself, for example: --script ~/project/source/train.py",
"repository itself, for example: --script ~/project/source/train.py. "
"To run a bash script, simply specify the path of that script; the script should "
"have the .sh extension, for example: --script init.sh"
)
parser.add_argument(
"--binary",
type=str,
default=None,
help="Binary used to launch the entry point. For example: '--binary python3', '--binary /bin/bash'."
"By default, the binary will be auto-detected."
)
parser.add_argument(
"--module",
@ -186,6 +195,8 @@ def cli():
print("Importing offline session: {}".format(args.import_offline_session))
Task.import_offline_session(args.import_offline_session)
else:
if args.script and args.script.endswith(".sh") and not args.binary:
print("Detected shell script. Binary will be set to '/bin/bash'")
create_populate = CreateAndPopulate(
project_name=args.project,
task_name=args.name,
@ -206,6 +217,7 @@ def cli():
add_task_init_call=not args.skip_task_init,
raise_on_missing_entries=True,
verbose=True,
binary=args.binary
)
# verify args before creating the Task
create_populate.update_task_args(args.args)

View File

@ -219,9 +219,9 @@ class StorageManager(object):
.. note::
If we have a local file `~/folder/sub/file.ext` then
`StorageManager.upload_folder('~/folder/', 's3://bucket/')`
will create `s3://bucket/sub/file.ext`
If we have a local file ``\~/folder/sub/file.ext`` then
``StorageManager.upload_folder('\~/folder/', 's3://bucket/')``
will create ``s3://bucket/sub/file.ext``
:param str local_folder: Local folder to recursively upload
:param str remote_url: Target remote storage location, tree structure of `local_folder` will