Add automation support including hyper-parameter optimization

This commit is contained in:
allegroai 2020-05-22 12:00:07 +03:00
parent b457b9aaad
commit 95105cbe6a
16 changed files with 1921 additions and 94 deletions

View File

@ -16,4 +16,4 @@ params['Example_Param'] = 1
task.connect(params)
# Print the value to demonstrate it is set by the initiating task.
print ("Example_Param is", params['Example_Param'])
print("Example_Param is", params['Example_Param'])

View File

@ -75,6 +75,7 @@ try:
},
index=['falcon', 'dog', 'spider', 'fish']
)
df.index.name = 'id'
logger.report_table("test table pd", "PD with index", 1, table_plot=df)
# Report table - CSV from path
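Naming the DataFrame index gives the reported table an explicit index-column header. For context, a minimal standalone sketch of the full reporting call (project/task names are placeholders):

import pandas as pd
from trains import Task

task = Task.init(project_name='examples', task_name='table reporting')
logger = task.get_logger()
df = pd.DataFrame(
    {'num_legs': [2, 4, 8, 0], 'num_wings': [2, 0, 0, 0]},
    index=['falcon', 'dog', 'spider', 'fish'],
)
df.index.name = 'id'
# title, series, iteration, then the table itself
logger.report_table("test table pd", "PD with index", 1, table_plot=df)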

View File

@ -0,0 +1,3 @@
from .parameters import UniformParameterRange, DiscreteParameterRange, UniformIntegerParameterRange, ParameterSet
from .optimization import GridSearch, RandomSearch, HyperParameterOptimizer, Objective
from .job import TrainsJob
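These exports map onto the new modules added below. A minimal sketch of the parameter API, grounded in the parameters.py diff later in this commit (the parameter names are placeholders):

from trains.automation import UniformParameterRange, DiscreteParameterRange

lr = UniformParameterRange('learning_rate', min_value=0.0001, max_value=0.1, step_size=0.0001)
batch = DiscreteParameterRange('batch_size', values=[16, 32, 64])
print(lr.get_value())    # e.g. {'learning_rate': 0.0123}, uniformly sampled
print(batch.to_list())   # [{'batch_size': 16}, {'batch_size': 32}, {'batch_size': 64}]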

View File

@ -0,0 +1 @@
from .bandster import OptimizerBOHB

View File

@ -0,0 +1,240 @@
from time import sleep, time
from ..parameters import DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange
from ..optimization import Objective, SearchStrategy
from ...task import Task
try:
from hpbandster.core.worker import Worker
from hpbandster.optimizers import BOHB
import hpbandster.core.nameserver as hpns
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
Task.add_requirements('hpbandster')
except ImportError:
raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
"install with: pip install hpbandster")
class TrainsBandsterWorker(Worker):
def __init__(self, *args, optimizer, base_task_id, queue_name, objective,
sleep_interval=0, budget_iteration_scale=1., **kwargs):
super(TrainsBandsterWorker, self).__init__(*args, **kwargs)
self.optimizer = optimizer
self.base_task_id = base_task_id
self.queue_name = queue_name
self.objective = objective
self.sleep_interval = sleep_interval
self.budget_iteration_scale = budget_iteration_scale
self._current_job = None
def compute(self, config, budget, **kwargs):
"""
Simple example for a compute function
The loss is just a the config + some noise (that decreases with the budget)
For dramatization, the function can sleep for a given interval to emphasizes
the speed ups achievable with parallel workers.
Args:
config: dictionary containing the sampled configurations by the optimizer
budget: (float) amount of time/epochs/etc. the model can use to train.
We assume budget is iteration, as time might not be stable from machine to machine.
Returns:
dictionary with mandatory fields:
'loss' (scalar)
'info' (dict)
"""
self._current_job = self.optimizer.helper_create_job(self.base_task_id, parameter_override=config)
self.optimizer._current_jobs.append(self._current_job)
self._current_job.launch(self.queue_name)
iteration_value = None
while not self._current_job.is_stopped():
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(self._current_job)
if iteration_value and iteration_value[0] >= self.budget_iteration_scale * budget:
self._current_job.abort()
break
sleep(self.sleep_interval)
result = {
# this is a mandatory field to run hyperband
# remember: HpBandSter always minimizes!
'loss': float(self.objective.get_normalized_objective(self._current_job)*-1.),
# can be used for any user-defined information - also mandatory
'info': self._current_job.task_id()
}
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
self.optimizer._current_jobs.remove(self._current_job)
return result
class OptimizerBOHB(SearchStrategy, RandomSeed):
def __init__(self, base_task_id, hyper_parameters, objective_metric,
execution_queue, num_concurrent_workers, min_iteration_per_job, max_iteration_per_job, total_max_jobs,
pool_period_min=2.0, max_job_execution_minutes=None, **bohb_kargs):
"""
Initialize a search strategy optimizer
:param str base_task_id: Task ID (str)
:param list hyper_parameters: list of Parameter objects to optimize over
:param Objective objective_metric: Objective metric to maximize / minimize
:param str execution_queue: execution queue to use for launching Tasks (experiments).
:param int num_concurrent_workers: Limit number of concurrent running machines
:param float min_iteration_per_job: minimum number of iterations for a job to run.
:param int max_iteration_per_job: maximum number of iterations per job.
:param int total_max_jobs: total maximum number of jobs for the optimization process. Default None, unlimited.
:param float pool_period_min: time in minutes between two consecutive status polls.
:param float max_job_execution_minutes: maximum time per single job in minutes; if exceeded, the job is aborted.
:param ** bohb_kargs: arguments passed directly to the BOHB object.
"""
super(OptimizerBOHB, self).__init__(
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes,
total_max_jobs=total_max_jobs)
self._max_iteration_per_job = max_iteration_per_job
self._min_iteration_per_job = min_iteration_per_job
self._bohb_kwargs = bohb_kargs or {}
self._param_iterator = None
self._namespace = None
self._bohb = None
self._res = None
def set_optimization_args(self, eta=3, min_budget=None, max_budget=None,
min_points_in_model=None, top_n_percent=15,
num_samples=None, random_fraction=1/3., bandwidth_factor=3,
min_bandwidth=1e-3):
"""
Defaults copied from BOHB constructor, see details in BOHB.__init__
BOHB performs robust and efficient hyperparameter optimization
at scale by combining the speed of Hyperband searches with the
guidance and guarantees of convergence of Bayesian
Optimization. Instead of sampling new configurations at random,
BOHB uses kernel density estimators to select promising candidates.
.. highlight:: none
For reference: ::
@InProceedings{falkner-icml-18,
title = {{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},
author = {Falkner, Stefan and Klein, Aaron and Hutter, Frank},
booktitle = {Proceedings of the 35th International Conference on Machine Learning},
pages = {1436--1445},
year = {2018},
}
Parameters
----------
eta : float (3)
In each iteration, a complete run of sequential halving is executed. In it,
after evaluating each configuration on the same subset size, only a fraction of
1/eta of them 'advances' to the next round.
Must be greater or equal to 2.
min_budget : float (0.01)
The smallest budget to consider. Needs to be positive!
max_budget : float (1)
The largest budget to consider. Needs to be larger than min_budget!
The budgets will be geometrically distributed
:math:`\sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
min_points_in_model: int (None)
number of observations to start building a KDE. Default 'None' means
dim+1, the bare minimum.
top_n_percent: int (15)
percentage ( between 1 and 99, default 15) of the observations that are considered good.
num_samples: int (64)
number of samples to optimize EI (default 64)
random_fraction: float (1/3.)
fraction of purely random configurations that are sampled from the
prior without the model.
bandwidth_factor: float (3.)
to encourage diversity, the points proposed to optimize EI, are sampled
from a 'widened' KDE where the bandwidth is multiplied by this factor (default: 3)
min_bandwidth: float (1e-3)
to keep diversity, even when all (good) samples have the same value for one of the parameters,
a minimum bandwidth (Default: 1e-3) is used instead of zero.
"""
if min_budget:
self._bohb_kwargs['min_budget'] = min_budget
if max_budget:
self._bohb_kwargs['max_budget'] = max_budget
if num_samples:
self._bohb_kwargs['num_samples'] = num_samples
self._bohb_kwargs['eta'] = eta
self._bohb_kwargs['min_points_in_model'] = min_points_in_model
self._bohb_kwargs['top_n_percent'] = top_n_percent
self._bohb_kwargs['random_fraction'] = random_fraction
self._bohb_kwargs['bandwidth_factor'] = bandwidth_factor
self._bohb_kwargs['min_bandwidth'] = min_bandwidth
def start(self):
# Step 1: Start a nameserver
fake_run_id = 'OptimizerBOHB_{}'.format(time())
self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=None)
self._namespace.start()
# we have to scale the budget to the iterations per job, otherwise numbers might be too high
budget_iteration_scale = self._max_iteration_per_job
# Step 2: Start the workers
workers = []
for i in range(self._num_concurrent_workers):
w = TrainsBandsterWorker(optimizer=self,
sleep_interval=int(self.pool_period_minutes*60),
budget_iteration_scale=budget_iteration_scale,
base_task_id=self._base_task_id,
objective=self._objective_metric,
queue_name=self._execution_queue,
nameserver='127.0.0.1', run_id=fake_run_id, id=i)
w.run(background=True)
workers.append(w)
# Step 3: Run an optimizer
# let bohb_kargs / set_optimization_args() override the derived defaults,
# instead of raising a duplicate-keyword TypeError in the BOHB constructor
self._bohb_kwargs.setdefault('num_samples', self.total_max_jobs)
self._bohb_kwargs.setdefault(
'min_budget', float(self._min_iteration_per_job) / float(self._max_iteration_per_job))
self._bohb = BOHB(configspace=self.convert_hyper_parameters_to_cs(),
run_id=fake_run_id,
**self._bohb_kwargs)
self._res = self._bohb.run(n_iterations=self.total_max_jobs, min_n_workers=self._num_concurrent_workers)
# Step 4: if we get here, Shutdown
self.stop()
def stop(self):
# After the optimizer run, we must shutdown the master and the nameserver.
self._bohb.shutdown(shutdown_workers=True)
self._namespace.shutdown()
if not self._res:
return
# Step 5: Analysis
id2config = self._res.get_id2config_mapping()
incumbent = self._res.get_incumbent_id()
all_runs = self._res.get_all_runs()
# Step 6: Print Analysis
print('Best found configuration:', id2config[incumbent]['config'])
print('A total of {} unique configurations were sampled.'.format(len(id2config.keys())))
print('A total of {} runs were executed.'.format(len(self._res.get_all_runs())))
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
print('The run took {:.1f} seconds to complete.'.format(
all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))
def convert_hyper_parameters_to_cs(self):
cs = CS.ConfigurationSpace(seed=self._seed)
for p in self._hyper_parameters:
if isinstance(p, UniformParameterRange):
hp = CSH.UniformFloatHyperparameter(
p.name, lower=p.min_value, upper=p.max_value, log=False, q=p.step_size)
elif isinstance(p, UniformIntegerParameterRange):
hp = CSH.UniformIntegerHyperparameter(
p.name, lower=p.min_value, upper=p.max_value, log=False, q=p.step_size)
elif isinstance(p, DiscreteParameterRange):
hp = CSH.CategoricalHyperparameter(p.name, choices=p.values)
else:
raise ValueError("HyperParameter type {} not supported yet with OptimizerBOHB".format(type(p)))
cs.add_hyperparameter(hp)
return cs
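Putting it together, a minimal usage sketch for this optimizer (the Task ID, queue name, and metric title/series are placeholders; Objective's (title, series) signature is an assumption, since the optimization.py diff is suppressed later on this page):

from trains.automation import Objective, UniformParameterRange, DiscreteParameterRange
from trains.automation.hpbandster import OptimizerBOHB

optimizer = OptimizerBOHB(
    base_task_id='<base_task_id>',    # template experiment to clone per sampled configuration
    hyper_parameters=[
        UniformParameterRange('learning_rate', min_value=0.0001, max_value=0.1),
        DiscreteParameterRange('batch_size', values=[16, 32, 64]),
    ],
    objective_metric=Objective(title='validation', series='loss'),
    execution_queue='default',
    num_concurrent_workers=2,
    min_iteration_per_job=100,
    max_iteration_per_job=1000,
    total_max_jobs=20,
)
optimizer.set_optimization_args(eta=3)
optimizer.start()   # blocks: runs BOHB, then calls stop() for shutdown and the analysis printout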

trains/automation/job.py Normal file
View File

@ -0,0 +1,279 @@
import hashlib
from datetime import datetime
from logging import getLogger
from time import time, sleep
from ..task import Task
from ..backend_api.services import tasks as tasks_service
from ..backend_api.services import events as events_service
logger = getLogger('trains.automation.job')
class TrainsJob(object):
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, parent=None, **kwargs):
"""
Create a new Task based on a base_task_id, with a different set of parameters
:param str base_task_id: base task id to clone from
:param dict parameter_override: dictionary of parameters and values to set for the cloned task
:param dict task_overrides: Task object specific overrides
:param list tags: additional tags to add to the newly cloned task
:param str parent: Set the newly created Task's parent task field, default: base_task_id.
:param dict kwargs: additional Task creation parameters
"""
self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)
if tags:
self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
if parameter_override:
params = self.task.get_parameters_as_dict()
params.update(parameter_override)
self.task.set_parameters_as_dict(params)
if task_overrides:
# todo: make sure it works
# noinspection PyProtectedMember
self.task._edit(task_overrides)
self.task_started = False
self._worker = None
def get_metric(self, title, series):
"""
Retrieve a specific scalar metric from the running Task.
:param str title: Graph title (metric)
:param str series: Series on the specific graph (variant)
:return tuple: min value, max value, last value
"""
title = hashlib.md5(str(title).encode('utf-8')).hexdigest()
series = hashlib.md5(str(series).encode('utf-8')).hexdigest()
metric = 'last_metrics.{}.{}.'.format(title, series)
values = ['min_value', 'max_value', 'value']
metrics = [metric + v for v in values]
res = self.task.send(
tasks_service.GetAllRequest(
id=[self.task.id],
page=0,
page_size=1,
only_fields=['id', ] + metrics
)
)
response = res.wait()
return tuple(response.response_data['tasks'][0]['last_metrics'][title][series][v] for v in values)
def launch(self, queue_name=None):
"""
Send Job for execution on the requested execution queue
:param str queue_name: name of the execution queue to enqueue the job into
"""
try:
Task.enqueue(task=self.task, queue_name=queue_name)
except Exception as ex:
logger.warning(ex)
def abort(self):
"""
Abort currently running job (can be called multiple times)
"""
try:
self.task.stopped()
except Exception as ex:
logger.warning(ex)
def elapsed(self):
"""
Return the time in seconds since job started. Return -1 if job is still pending
:return float: seconds from start
"""
if not self.task_started and self.task.status != Task.TaskStatusEnum.in_progress:
return -1
self.task_started = True
return (datetime.now() - self.task.data.started).timestamp()
def iterations(self):
"""
Return the last iteration value of the current job. -1 if job has not started yet
:return int: Task last iteration
"""
if not self.task_started and self.task.status != Task.TaskStatusEnum.in_progress:
return -1
self.task_started = True
return self.task.get_last_iteration()
def task_id(self):
"""
Return the Task id.
:return str: Task id
"""
return self.task.id
def status(self):
"""
Return the Job Task current status, see Task.TaskStatusEnum
:return str: Task status (a Task.TaskStatusEnum value) as a string
"""
return self.task.status
def wait(self, timeout=None, pool_period=30.):
"""
Wait until the task is fully executed (i.e. aborted/completed/failed)
:param timeout: maximum time (minutes) to wait for Task to finish
:param pool_period: check task status every pool_period seconds
:return bool: Return True if the Task finished.
"""
tic = time()
while timeout is None or time()-tic < timeout*60.:
if self.is_stopped():
return True
sleep(pool_period)
return self.is_stopped()
def get_console_output(self, number_of_reports=1):
"""
Return a list of console outputs reported by the Task.
The outputs returned are the most recently reported ones.
:param int number_of_reports: number of reports to return, default 1, the last (most updated) console output
:return list: List of strings each entry corresponds to one report.
"""
return self.task.get_reported_console_output(number_of_reports=number_of_reports)
def worker(self):
"""
Return the current worker id executing this Job. If job is pending, returns None
:return str: Worker ID (str) executing / executed the job, or None if job is still pending.
"""
if self.is_pending():
return self._worker
if self._worker is None:
# the last console outputs will update the worker
self.get_console_output(number_of_reports=1)
# if we still do not have it, store empty string
if not self._worker:
self._worker = ''
return self._worker
def is_running(self):
"""
Return True if job is currently running (pending is considered False)
:return bool: True iff the task is currently in progress
"""
return self.task.status == Task.TaskStatusEnum.in_progress
def is_stopped(self):
"""
Return True if the job was executed and is no longer running
:return bool: True if the task is currently in one of these states: stopped / completed / failed / published
"""
return self.task.status in (
Task.TaskStatusEnum.stopped, Task.TaskStatusEnum.completed,
Task.TaskStatusEnum.failed, Task.TaskStatusEnum.published)
def is_pending(self):
"""
Return True if job is waiting for execution
:return bool: True if the task is currently queued or created
"""
return self.task.status in (Task.TaskStatusEnum.queued, Task.TaskStatusEnum.created)
class JobStub(object):
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, **_):
self.task = None
self.base_task_id = base_task_id
self.parameter_override = parameter_override
self.task_overrides = task_overrides
self.tags = tags
self.iteration = -1
self.task_started = None
def launch(self, queue_name=None):
self.iteration = 0
self.task_started = time()
print('launching', self.parameter_override, 'in', queue_name)
def abort(self):
self.task_started = -1
def elapsed(self):
"""
Return the time in seconds since job started. Return -1 if job is still pending
:return float: seconds from start
"""
if self.task_started is None:
return -1
return time() - self.task_started
def iterations(self):
"""
Return the last iteration value of the current job. -1 if job has not started yet
:return int: Task last iteration
"""
if self.task_started is None:
return -1
return self.iteration
def get_metric(self, title, series):
"""
Retrieve a specific scalar metric from the running Task.
:param str title: Graph title (metric)
:param str series: Series on the specific graph (variant)
:return tuple: min value, max value, last value
"""
return 0, 1.0, 0.123
def task_id(self):
return 'stub'
def worker(self):
return None
def status(self):
return 'in_progress'
def wait(self, timeout=None, pool_period=30.):
"""
Wait for the task to be processed (i.e. aborted/completed/failed)
:param timeout: maximum time (minutes) to wait for Task to finish
:param pool_period: check task status every pool_period seconds
:return bool: Return True if the Task finished.
"""
return True
def get_console_output(self, number_of_reports=1):
"""
Return a list of console outputs reported by the Task.
The outputs returned are the most recently reported ones.
:param int number_of_reports: number of reports to return, default 1, the last (most updated) console output
:return list: List of strings each entry corresponds to one report.
"""
return []
def is_running(self):
return self.task_started is not None and self.task_started > 0
def is_stopped(self):
return self.task_started is not None and self.task_started < 0
def is_pending(self):
return self.task_started is None
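A minimal usage sketch for TrainsJob (the base Task ID, queue name, and metric title/series are placeholders):

from trains.automation import TrainsJob

job = TrainsJob(base_task_id='<base_task_id>', parameter_override={'Example_Param': 2})
job.launch(queue_name='default')         # enqueue the clone for a trains-agent to execute
job.wait(timeout=60, pool_period=30.)    # timeout in minutes, polling every 30 seconds
min_value, max_value, last_value = job.get_metric('validation', 'loss')
print(job.status(), job.worker(), last_value)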

File diff suppressed because it is too large

View File

@ -0,0 +1,263 @@
import sys
from itertools import product
from random import Random as BaseRandom
import numpy as np
class RandomSeed(object):
"""
Base class controlling random sampling for every optimization strategy
"""
_random = BaseRandom(1337)
_seed = 1337
@staticmethod
def set_random_seed(seed=1337):
"""
Set global seed for all hyper parameter strategy random number sampling
:param int seed: random seed
"""
RandomSeed._seed = seed
RandomSeed._random = BaseRandom(seed)
@staticmethod
def get_random_seed():
"""
Get the global seed for all hyper parameter strategy random number sampling
:return int: random seed
"""
return RandomSeed._seed
class Parameter(RandomSeed):
"""
Base Hyper-Parameter optimization object
"""
def __init__(self, name):
self.name = name
def get_value(self):
"""
Return a dict with the Parameter name and a sampled value for the Parameter
:return dict: example, {'answer': 0.42}
"""
pass
def to_list(self):
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: values}
"""
pass
def to_dict(self):
"""
Return a dict representation of the parameter object
:return: dict
"""
serialize = {'__class__': str(self.__class__).split('.')[-1][:-2]}
serialize.update(dict(((k, v.to_dict() if hasattr(v, 'to_dict') else v) for k, v in self.__dict__.items())))
return serialize
@classmethod
def from_dict(cls, a_dict):
"""
Construct parameter object from a dict representation
:return: Parameter
"""
a_dict = a_dict.copy()
a_cls = a_dict.pop('__class__', None)
if not a_cls:
return None
try:
a_cls = getattr(sys.modules[__name__], a_cls)
except AttributeError:
return None
instance = a_cls.__new__(a_cls)
instance.__dict__ = dict((k, cls.from_dict(v) if isinstance(v, dict) and '__class__' in v else v)
for k, v in a_dict.items())
return instance
class UniformParameterRange(Parameter):
"""
Uniform randomly sampled Hyper-Parameter object
"""
def __init__(self, name, min_value, max_value, step_size=None, include_max_value=True):
"""
Create a parameter to be sampled by the SearchStrategy
:param str name: parameter name, should match task hyper-parameter name
:param float min_value: minimum sample to use for uniform random sampling
:param float max_value: maximum sample to use for uniform random sampling
:param float step_size: If not None, set step size (quantization) for value sampling
:param bool include_max_value: if True range includes the max_value (default True)
"""
super(UniformParameterRange, self).__init__(name=name)
self.min_value = float(min_value)
self.max_value = float(max_value)
self.step_size = float(step_size) if step_size is not None else None
self.include_max = include_max_value
def get_value(self):
"""
Return uniformly sampled value based on object sampling definitions
:return dict: {self.name: random value [self.min_value, self.max_value)}
"""
if not self.step_size:
return {self.name: self._random.uniform(self.min_value, self.max_value)}
steps = (self.max_value - self.min_value) / self.step_size
return {self.name: self.min_value + (self._random.randrange(start=0, stop=round(steps)) * self.step_size)}
def to_list(self):
"""
Return a list of all the valid values of the Parameter
if self.step_size is not defined, return 100 points between min and max values
:return list: list of dicts {name: float}
"""
values = np.arange(start=self.min_value, stop=self.max_value,
step=self.step_size or (self.max_value - self.min_value) / 100.,
dtype=type(self.min_value)).tolist()
if self.include_max and (not values or values[-1] < self.max_value):
values.append(self.max_value)
return [{self.name: v} for v in values]
class UniformIntegerParameterRange(Parameter):
"""
Uniform randomly sampled integer Hyper-Parameter object
"""
def __init__(self, name, min_value, max_value, step_size=1, include_max_value=True):
"""
Create a parameter to be sampled by the SearchStrategy
:param str name: parameter name, should match task hyper-parameter name
:param int min_value: minimum sample to use for uniform random sampling
:param int max_value: maximum sample to use for uniform random sampling
:param int step_size: Default step size is 1
:param bool include_max_value: if True range includes the max_value (default True)
"""
super(UniformIntegerParameterRange, self).__init__(name=name)
self.min_value = int(min_value)
self.max_value = int(max_value)
self.step_size = int(step_size) if step_size is not None else None
self.include_max = include_max_value
def get_value(self):
"""
Return uniformly sampled value based on object sampling definitions
:return dict: {self.name: random value [self.min_value, self.max_value)}
"""
return {self.name: self._random.randrange(
start=self.min_value, step=self.step_size,
stop=self.max_value + (0 if not self.include_max else self.step_size))}
def to_list(self):
"""
Return a list of all the valid values of the Parameter
if self.step_size is not defined, return 100 points between min and max values
:return list: list of dicts {name: int}
"""
values = list(range(self.min_value, self.max_value, self.step_size))
if self.include_max and (not values or values[-1] < self.max_value):
values.append(self.max_value)
return [{self.name: v} for v in values]
class DiscreteParameterRange(Parameter):
"""
Discrete randomly sampled Hyper-Parameter object
"""
def __init__(self, name, values=()):
"""
Uniformly sample values from a list of discrete options
:param str name: parameter name, should match task hyper-parameter name
:param list values: list/tuple of valid parameter values to sample from
"""
super(DiscreteParameterRange, self).__init__(name=name)
self.values = values
def get_value(self):
"""
Return uniformly sampled value from the valid list of values
:return dict: {self.name: random entry from self.values}
"""
return {self.name: self._random.choice(self.values)}
def to_list(self):
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: value}
"""
return [{self.name: v} for v in self.values]
class ParameterSet(Parameter):
"""
Discrete randomly sampled set of Hyper-Parameter combinations
"""
def __init__(self, parameter_combinations=()):
"""
Uniformly sample values from a list of discrete options (combinations) of parameters
:param list parameter_combinations: list/tuple of valid parameter combinations,
Example: two combinations with three specific parameters per combination:
[ {'opt1': 10, 'arg2': 20, 'arg3': 30},
{'opt2': 11, 'arg2': 22, 'arg3': 33}, ]
Example: Two complex combinations, each sampling 'arg1' from a different range:
[ {'opt1': UniformParameterRange('arg1', 0, 1), 'arg2': 20},
{'opt2': UniformParameterRange('arg1', 11, 12), 'arg2': 22}, ]
"""
super(ParameterSet, self).__init__(name=None)
self.values = parameter_combinations
def get_value(self):
"""
Return uniformly sampled value from the valid list of values
:return dict: uniformly sampled combination from self.values (Parameter entries are sampled as well)
"""
return self._get_value(self._random.choice(self.values))
def to_list(self):
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: value}
"""
combinations = []
for combination in self.values:
single_option = {}
for k, v in combination.items():
if isinstance(v, Parameter):
single_option[k] = v.to_list()
else:
single_option[k] = [{k: v}, ]
for state in product(*single_option.values()):
combinations.append(dict(kv for d in state for kv in d.items()))
return combinations
@staticmethod
def _get_value(combination):
value_dict = {}
for k, v in combination.items():
if isinstance(v, Parameter):
value_dict.update(v.get_value())
else:
value_dict[k] = v
return value_dict
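A short sketch of how ParameterSet expands mixed combinations (all names are illustrative):

from trains.automation import ParameterSet, UniformParameterRange

ps = ParameterSet([
    {'optimizer': 'sgd', 'lr': UniformParameterRange('lr', min_value=0.1, max_value=0.3, step_size=0.1)},
    {'optimizer': 'adam', 'lr': 0.001},
])
print(ps.get_value())  # e.g. {'optimizer': 'sgd', 'lr': 0.2} -- Parameter entries are sampled
print(ps.to_list())    # per-combination cross-product; fixed values pass through unchanged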

View File

@ -161,6 +161,8 @@ class Session(TokenManager):
api_version = token_dict.get('api_version')
if not api_version:
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
if token_dict.get('server_version'):
Session._client.append(('trains-server', token_dict.get('server_version'), ))
Session.api_version = str(api_version)
except (jwt.DecodeError, ValueError):

View File

@ -11,6 +11,10 @@ class EnvEntry(Entry):
conversions[bool] = text_to_bool
return conversions
def __init__(self, key, *more_keys, **kwargs):
super(EnvEntry, self).__init__(key, *more_keys, **kwargs)
self._ignore_errors = kwargs.pop('ignore_errors', False)
def _get(self, key):
value = getenv(key, "").strip()
return value or NotSet
@ -22,7 +26,8 @@ class EnvEntry(Entry):
return "env:{}".format(super(EnvEntry, self).__str__())
def error(self, message):
print("Environment configuration: {}".format(message))
if not self._ignore_errors:
print("Environment configuration: {}".format(message))
def exists(self):
return any(key for key in self.keys if getenv(key) is not None)
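The new ignore_errors flag silences the "Environment configuration" error printout for entries where a failed lookup or conversion is acceptable. A sketch based on the CI-detection usage at the end of this commit (the import path is an assumption):

from trains.backend_config.environment import EnvEntry  # import path is an assumption

# parse the CI environment variable as a bool; conversion errors are not printed
in_ci = EnvEntry('CI', type=bool, ignore_errors=True).get()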

View File

@ -1375,7 +1375,11 @@ class OutputModel(BaseModel):
self._floating_data = None
# now we have to update the creator task so it points to us
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
if self._task.status not in (self._task.TaskStatusEnum.created, self._task.TaskStatusEnum.in_progress):
self._log.warning('Could not update last created model in Task {}, '
'Task status \'{}\' cannot be updated'.format(self._task.id, self._task.status))
else:
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
return self._base_model

View File

@ -27,11 +27,17 @@ class CacheManager(object):
def get_local_copy(self, remote_url):
helper = StorageHelper.get(remote_url)
if not helper:
raise ValueError("Remote storage not supported: {}".format(remote_url))
raise ValueError("Storage access failed: {}".format(remote_url))
# check if we need to cache the file
direct_access = helper._driver.get_direct_access(remote_url)
try:
direct_access = helper._driver.get_direct_access(remote_url)
except (OSError, ValueError):
LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url))
return None
if direct_access:
return direct_access
# check if we already have the file in our cache
cached_file, cached_size = self._get_cache_file(remote_url)
if cached_size is not None:

View File

@ -263,7 +263,8 @@ class StorageHelper(object):
log.error(str(ex))
return None
except Exception as ex:
log.error("Failed credentials for {}. Reason: {}".format(base_url or url, ex))
log.error("Failed creating storage object {} Reason: {}".format(
base_url or url, ex))
return None
cls._helpers[instance_key] = instance
@ -361,14 +362,16 @@ class StorageHelper(object):
# url2pathname is specifically intended to operate on (urlparse result).path
# and returns a cross-platform compatible result
driver_uri = url2pathname(url)
if Path(driver_uri).is_file():
driver_uri = str(Path(driver_uri).parent)
elif not Path(driver_uri).exists():
# assume a folder and create
Path(driver_uri).mkdir(parents=True, exist_ok=True)
path_driver_uri = Path(driver_uri)
# if path_driver_uri.is_file():
# driver_uri = str(path_driver_uri.parent)
# elif not path_driver_uri.exists():
# # assume a folder and create
# # Path(driver_uri).mkdir(parents=True, exist_ok=True)
# pass
self._driver = _FileStorageDriver(driver_uri)
self._container = self._driver.get_container(container_name='.')
self._driver = _FileStorageDriver(str(path_driver_uri.root))
self._container = None
@classmethod
def terminate_uploads(cls, force=True, timeout=2.0):
@ -910,8 +913,8 @@ class StorageHelper(object):
except Exception as e:
self._log.error("Calling upload callback when starting upload: %s" % str(e))
if verbose:
msg = 'Starting upload: {} => {}{}'.format(src_path, self._container.name if self._container else '',
object_name)
msg = 'Starting upload: {} => {}{}'.format(
src_path, self._container.name if self._container else '', object_name)
if object_name.startswith('file://') or object_name.startswith('/'):
self._log.debug(msg)
else:
@ -946,7 +949,8 @@ class StorageHelper(object):
def _get_object(self, path):
object_name = self._normalize_object_name(path)
try:
return self._driver.get_object(container_name=self._container.name, object_name=object_name)
return self._driver.get_object(
container_name=self._container.name if self._container else '', object_name=object_name)
except ConnectionError as ex:
raise DownloadError
except Exception as e:
@ -1757,9 +1761,6 @@ class _FileStorageDriver(_Driver):
# Use the key as the path to the storage
self.base_path = key
if not os.path.isdir(self.base_path):
raise ValueError('The base path is not a directory')
def _make_path(self, path, ignore_existing=True):
"""
Create a path by checking if it already exists
@ -1781,7 +1782,7 @@ class _FileStorageDriver(_Driver):
"""
if '/' in container_name or '\\' in container_name:
raise ValueError(value=None, driver=self, container_name=container_name)
raise ValueError("Container name \"{}\" cannot contain \\ or / ".format(container_name))
def _make_container(self, container_name):
"""
@ -1793,7 +1794,7 @@ class _FileStorageDriver(_Driver):
:return: Container instance.
:rtype: :class:`Container`
"""
container_name = container_name or '.'
self._check_container_name(container_name)
full_path = os.path.realpath(os.path.join(self.base_path, container_name))
@ -1801,14 +1802,15 @@ class _FileStorageDriver(_Driver):
try:
stat = os.stat(full_path)
if not os.path.isdir(full_path):
raise OSError('Target path is not a directory')
raise OSError("Target path \"{}\" is not a directory".format(full_path))
except OSError:
raise ValueError(value=None, driver=self, container_name=container_name)
raise OSError("Target path \"{}\" is not accessible or does not exist".format(full_path))
extra = {}
extra['creation_time'] = stat.st_ctime
extra['access_time'] = stat.st_atime
extra['modify_time'] = stat.st_mtime
extra = {
'creation_time': stat.st_ctime,
'access_time': stat.st_atime,
'modify_time': stat.st_mtime,
}
return self._Container(name=container_name, extra=extra, driver=self)
@ -1826,20 +1828,21 @@ class _FileStorageDriver(_Driver):
:rtype: :class:`Object`
"""
full_path = os.path.realpath(os.path.join(self.base_path, container.name, object_name))
full_path = os.path.realpath(os.path.join(self.base_path, container.name if container else '.', object_name))
if os.path.isdir(full_path):
raise ValueError(value=None, driver=self, object_name=object_name)
raise ValueError("Target path \"{}\" already exist".format(full_path))
try:
stat = os.stat(full_path)
except Exception:
raise ValueError(value=None, driver=self, object_name=object_name)
raise ValueError("Cannot access target path \"{}\"".format(full_path))
extra = {}
extra['creation_time'] = stat.st_ctime
extra['access_time'] = stat.st_atime
extra['modify_time'] = stat.st_mtime
extra = {
'creation_time': stat.st_ctime,
'access_time': stat.st_atime,
'modify_time': stat.st_mtime,
}
return self.Object(name=object_name, size=stat.st_size, extra=extra,
driver=self, container=container, hash=None, meta_data=None)
@ -1914,10 +1917,10 @@ class _FileStorageDriver(_Driver):
:return: A CDN URL for this container.
:rtype: ``str``
"""
path = os.path.realpath(os.path.join(self.base_path, container.name))
path = os.path.realpath(os.path.join(self.base_path, container.name if container else '.'))
if check and not os.path.isdir(path):
raise ValueError(value=None, driver=self, container_name=container.name)
raise ValueError("Target path \"{}\" does not exist".format(path))
return path
@ -1977,9 +1980,7 @@ class _FileStorageDriver(_Driver):
base_name = os.path.basename(destination_path)
if not base_name and not os.path.exists(destination_path):
raise ValueError(
value='Path %s does not exist' % (destination_path),
driver=self)
raise ValueError('Path \"{}\" does not exist'.format(destination_path))
if not base_name:
file_path = os.path.join(destination_path, obj.name)
@ -1987,7 +1988,7 @@ class _FileStorageDriver(_Driver):
file_path = destination_path
if os.path.exists(file_path) and not overwrite_existing:
raise ValueError('File %s already exists, but ' % (file_path) + 'overwrite_existing=False')
raise ValueError('File \"{}\" already exists, but overwrite_existing=False'.format(file_path))
try:
shutil.copy(obj_path, file_path)
@ -2146,7 +2147,7 @@ class _FileStorageDriver(_Driver):
:return: :class:`Container` instance on success.
:rtype: :class:`Container`
"""
container_name = container_name or '.'
self._check_container_name(container_name)
path = os.path.join(self.base_path, container_name)
@ -2156,13 +2157,13 @@ class _FileStorageDriver(_Driver):
except OSError:
exp = sys.exc_info()[1]
if exp.errno == errno.EEXIST:
raise ValueError('Container %s with this name already exists. The name '
raise ValueError('Container \"{}\" with this name already exists. The name '
'must be unique among all the containers in the '
'system' % container_name)
'system'.format(container_name))
else:
raise ValueError( 'Error creating container %s' % container_name)
raise ValueError('Error creating container \"{}\"'.format(container_name))
except Exception:
raise ValueError('Error creating container %s' % container_name)
raise ValueError('Error creating container \"{}\"'.format(container_name))
return self._make_container(container_name)
@ -2179,7 +2180,7 @@ class _FileStorageDriver(_Driver):
# Check if there are any objects inside this
for obj in self._get_objects(container):
raise ValueError(value='Container %s is not empty' % container.name)
raise ValueError('Container \"{}\" is not empty'.format(container.name))
path = self.get_container_cdn_url(container, check=True)
@ -2259,7 +2260,10 @@ class _FileStorageDriver(_Driver):
# this will always make sure we have full path and file:// prefix
full_url = StorageHelper.conform_url(remote_path)
# now get rid of the file:// prefix
return Path(full_url[7:]).as_posix()
path = Path(full_url[7:])
if not path.exists():
raise ValueError("Requested path does not exist: {}".format(path))
return path.as_posix()
def test_upload(self, test_path, config, **kwargs):
return True

View File

@ -9,13 +9,23 @@ def get_config_object_matcher(**patterns):
raise ValueError('Unsupported object matcher (expecting string): %s'
% ', '.join('%s=%s' % (k, v) for k, v in unsupported.items()))
# optimize simple patterns
starts_with = {k: v.rstrip('*') for k, v in patterns.items() if '*' not in v.rstrip('*') and '?' not in v}
patterns = {k: v for k, v in patterns.items() if k not in starts_with}
def _matcher(**kwargs):
for key, value in kwargs.items():
if not value:
continue
pat = patterns.get(key)
if pat and fnmatch.fnmatch(value, pat):
return True
start = starts_with.get(key)
if start:
if value.startswith(start):
return True
else:
pat = patterns.get(key)
if pat and fnmatch.fnmatch(value, pat):
return True
return _matcher
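The rewrite splits simple prefix patterns (literal text plus an optional trailing '*') out of the fnmatch path, so common cases are matched with a cheap str.startswith. A sketch of the resulting behavior (assuming get_config_object_matcher from the diff above is in scope; keys and patterns are illustrative):

matcher = get_config_object_matcher(project='MyProject*', script='*train*.py')
print(matcher(project='MyProject Tests'))      # True, via the fast str.startswith path
print(matcher(script='/repo/train_model.py'))  # True, via the fnmatch path
print(matcher(project='Other'))                # falsy (None): nothing matched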

View File

@ -156,8 +156,6 @@ class Task(_Task):
object.
:return: The current running Task (experiment).
:rtype: Task() object or ``None``
"""
return cls.__main_task
@ -285,8 +283,6 @@ class Task(_Task):
- ``False`` - Do not automatically create.
:return: The main execution Task (Task context).
:rtype: Task object
"""
def verify_defaults_match():
@ -509,8 +505,6 @@ class Task(_Task):
:type task_type: TaskTypeEnum(value)
:return: A new experiment.
:rtype: Task() object
"""
if not project_name:
if not cls.__main_task:
@ -543,8 +537,6 @@ class Task(_Task):
:param str task_name: The name of the Task within ``project_name`` to get.
:return: The Task specified by Id, or project name / experiment name combination.
:rtype: Task object
"""
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
@ -574,8 +566,6 @@ class Task(_Task):
If None is passed, returns all tasks within the project
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
:return: The Tasks specified by the parameter combinations (see the parameters).
:rtype: List of Task objects
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {}))
@ -604,8 +594,6 @@ class Task(_Task):
A read-only dictionary of Task artifacts (name, artifact).
:return: The artifacts.
:rtype: dict
"""
if not Session.check_min_api_version('2.3'):
return ReadOnlyDict()
@ -649,8 +637,6 @@ class Task(_Task):
If ``None``, the new task inherits the original Task's project. (Optional)
:return: The new cloned Task (experiment).
:rtype: Task object
"""
assert isinstance(source_task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
@ -685,7 +671,7 @@ class Task(_Task):
:param str queue_name: The name of the queue. If not specified, then ``queue_id`` must be specified.
:param str queue_id: The Id of the queue. If not specified, then ``queue_name`` must be specified.
:return: An enqueue response.
:return: An enqueue JSON response.
.. code-block:: javascript
@ -713,8 +699,6 @@ class Task(_Task):
- ``last_update`` - The last Task update time, including Task creation, update, change, or events for
this task (ISO 8601 format).
- ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued.
:rtype: JSON
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
@ -746,7 +730,7 @@ class Task(_Task):
:param task: The Task to dequeue. Specify a Task object or Task Id.
:type task: Task object / str
:return: A dequeue response.
:return: A dequeue JSON response.
.. code-block:: javascript
@ -774,8 +758,6 @@ class Task(_Task):
- ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued.
- ``updated`` - The number of Tasks updated (an integer or ``null``).
:rtype: JSON
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
@ -871,10 +853,8 @@ class Task(_Task):
:type configuration: dict, pathlib.Path/str
:return: If a dictictonary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned.
:rtype: Configuration object
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
specified, then a path to a local configuration file is returned. Configuration object.
"""
if not isinstance(configuration, (dict, Path, six.string_types)):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
@ -936,9 +916,7 @@ class Task(_Task):
'person': 1
}
:return: The label enumeration dictionary.
:rtype: JSON
:return: The label enumeration dictionary (JSON).
"""
if not isinstance(enumeration, dict):
raise ValueError("connect_label_enumeration supports only `dict` type, "
@ -960,8 +938,6 @@ class Task(_Task):
**Trains Web-App (UI)**.
:return: The Logger for the Task (experiment).
:rtype: Logger object
"""
return self._get_logger()
@ -1106,8 +1082,6 @@ class Task(_Task):
After calling ``get_registered_artifacts``, you can still modify the registered artifacts.
:return: The registered (dynamically synchronized) artifacts.
:rtype: dict
"""
return self._artifacts_manager.registered_artifacts
@ -1149,8 +1123,6 @@ class Task(_Task):
- ``True`` - Upload succeeded.
- ``False`` - Upload failed.
:rtype: bool
:raise: If the artifact object type is not supported, raise a ``ValueError``.
"""
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
@ -1183,8 +1155,6 @@ class Task(_Task):
- ``True`` - Is the main execution Task.
- ``False`` - Is not the main execution Task.
:rtype: bool
"""
return self.is_main_task()
@ -1205,8 +1175,6 @@ class Task(_Task):
- ``True`` - Is the main execution Task.
- ``False`` - Is not the main execution Task.
:rtype: bool
"""
return self is self.__main_task
@ -1263,8 +1231,6 @@ class Task(_Task):
sends a request to the **Trains Server** (backend).
:return: The last reported iteration number.
:rtype: int
"""
self._reload_last_iteration()
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
@ -1320,8 +1286,6 @@ class Task(_Task):
}
:return: The last scalar metrics.
:rtype: dict
"""
self.reload()
metrics = self.data.last_metrics
@ -1339,8 +1303,6 @@ class Task(_Task):
.. note::
The values are not parsed. They are returned as is.
:rtype: str
"""
return naive_nested_from_flat_dictionary(self.get_parameters())

View File

@ -27,10 +27,13 @@ class CheckPackageUpdates(object):
cls._package_version_checked = True
client, version = Session._client[0]
version = Version(version)
is_demo = 'https://demoapi.trains.allegro.ai/'.startswith(Session.get_api_server_host())
update_server_releases = requests.get(
'https://updates.trains.allegro.ai/updates',
json={"versions": {c: str(v) for c, v in Session._client}},
json={"demo": is_demo,
"versions": {c: str(v) for c, v in Session._client},
"CI": str(os.environ.get('CI', ''))},
timeout=3.0
)
@ -43,6 +46,10 @@ class CheckPackageUpdates(object):
if "version" not in client_answer:
return None
# do not output upgrade message if we are running inside a CI process.
if EnvEntry("CI", type=bool, ignore_errors=True).get():
return None
latest_version = Version(client_answer["version"])
if version >= latest_version: