mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add automation support including hyper-parameters optimization
This commit is contained in:
parent
b457b9aaad
commit
95105cbe6a
@ -16,4 +16,4 @@ params['Example_Param'] = 1
|
||||
task.connect(params)
|
||||
|
||||
# Print the value to demonstrate that it was set by the initiating task.
|
||||
print ("Example_Param is", params['Example_Param'])
|
||||
print("Example_Param is", params['Example_Param'])
|
||||
|
||||
@ -75,6 +75,7 @@ try:
|
||||
},
|
||||
index=['falcon', 'dog', 'spider', 'fish']
|
||||
)
|
||||
df.index.name = 'id'
|
||||
logger.report_table("test table pd", "PD with index", 1, table_plot=df)
|
||||
|
||||
# Report table - CSV from path
|
||||
|
||||
3
trains/automation/__init__.py
Normal file
3
trains/automation/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .parameters import UniformParameterRange, DiscreteParameterRange, UniformIntegerParameterRange, ParameterSet
|
||||
from .optimization import GridSearch, RandomSearch, HyperParameterOptimizer, Objective
|
||||
from .job import TrainsJob
|
||||
1
trains/automation/hpbandster/__init__.py
Normal file
1
trains/automation/hpbandster/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .bandster import OptimizerBOHB
|
||||
240
trains/automation/hpbandster/bandster.py
Normal file
240
trains/automation/hpbandster/bandster.py
Normal file
@ -0,0 +1,240 @@
|
||||
from time import sleep, time
|
||||
from ..parameters import DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange
|
||||
from ..optimization import Objective, SearchStrategy
|
||||
from ...task import Task
|
||||
|
||||
try:
|
||||
from hpbandster.core.worker import Worker
|
||||
from hpbandster.optimizers import BOHB
|
||||
import hpbandster.core.nameserver as hpns
|
||||
import ConfigSpace as CS
|
||||
import ConfigSpace.hyperparameters as CSH
|
||||
Task.add_requirements('hpbandster')
|
||||
except ImportError:
|
||||
raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
|
||||
"install with: pip install hpbandster")
|
||||
|
||||
|
||||
class TrainsBandsterWorker(Worker):
    """
    HpBandSter worker that evaluates one sampled configuration per `compute` call by
    creating a job from the base Task, enqueuing it, and polling its objective.
    """

    def __init__(self, *args, optimizer, base_task_id, queue_name, objective,
                 sleep_interval=0, budget_iteration_scale=1., **kwargs):
        """
        :param optimizer: owning optimizer; `compute` uses its helper_create_job(),
            _current_jobs list and _objective_metric
        :param str base_task_id: Task ID passed to optimizer.helper_create_job()
        :param str queue_name: execution queue the created job is launched on
        :param objective: object providing get_normalized_objective(job), used for the loss
        :param sleep_interval: seconds to sleep between two consecutive polls of the job
        :param float budget_iteration_scale: factor converting the BOHB budget into iterations
        :param args / kwargs: forwarded untouched to hpbandster's Worker constructor
        """
        super(TrainsBandsterWorker, self).__init__(*args, **kwargs)
        self.optimizer = optimizer
        self.base_task_id = base_task_id
        self.queue_name = queue_name
        self.objective = objective
        self.sleep_interval = sleep_interval
        self.budget_iteration_scale = budget_iteration_scale
        self._current_job = None

    def compute(self, config, budget, **kwargs):
        """
        Run a single job with the sampled configuration and report its loss.

        The job is created with `config` as parameter override, launched on the
        execution queue and polled every `sleep_interval` seconds. It is aborted once
        the first element of the raw objective (compared against the budget, i.e. treated
        as the iteration counter) reaches `budget_iteration_scale * budget`.

        Args:
            config: dictionary containing the sampled configuration by the optimizer
            budget: (float) budget assigned to this run; scaled by budget_iteration_scale
                before being compared with the job progress.
        Returns:
            dictionary with mandatory fields:
                'loss' (scalar) negated normalized objective (HpBandSter always minimizes)
                'info' the id of the executed Task
        """
        self._current_job = self.optimizer.helper_create_job(self.base_task_id, parameter_override=config)
        self.optimizer._current_jobs.append(self._current_job)
        try:
            self._current_job.launch(self.queue_name)
            iteration_value = None
            while not self._current_job.is_stopped():
                iteration_value = self.optimizer._objective_metric.get_current_raw_objective(self._current_job)
                if iteration_value and iteration_value[0] >= self.budget_iteration_scale * budget:
                    # budget exhausted, stop the job and report what we have
                    self._current_job.abort()
                    break
                sleep(self.sleep_interval)

            result = {
                # this is a mandatory field to run hyperband
                # remember: HpBandSter always minimizes!
                'loss': float(self.objective.get_normalized_objective(self._current_job)*-1.),
                # can be used for any user-defined information - also mandatory
                'info': self._current_job.task_id()
            }
            print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
        finally:
            # always unregister the job, otherwise a failed evaluation would leave
            # a stale entry in the optimizer's list of running jobs
            if self._current_job in self.optimizer._current_jobs:
                self.optimizer._current_jobs.remove(self._current_job)
        return result
|
||||
|
||||
|
||||
class OptimizerBOHB(SearchStrategy, RandomSeed):
    """
    Search strategy driving hpbandster's BOHB optimizer, evaluating every sampled
    configuration with a TrainsBandsterWorker.
    """

    def __init__(self, base_task_id, hyper_parameters, objective_metric,
                 execution_queue, num_concurrent_workers, min_iteration_per_job, max_iteration_per_job, total_max_jobs,
                 pool_period_min=2.0, max_job_execution_minutes=None, **bohb_kargs):
        """
        Initialize a search strategy optimizer

        :param str base_task_id: Task ID (str)
        :param list hyper_parameters: list of Parameter objects to optimize over
        :param Objective objective_metric: Objective metric to maximize / minimize
        :param str execution_queue: execution queue to use for launching Tasks (experiments).
        :param int num_concurrent_workers: Limit number of concurrent running machines
        :param float min_iteration_per_job: minimum number of iterations for a job to run.
        :param int max_iteration_per_job: number of iteration per job
        :param int total_max_jobs: total maximum job for the optimization process.
        :param float pool_period_min: time in minutes between two consecutive pools
        :param float max_job_execution_minutes: maximum time per single job in minutes, if exceeded job is aborted
        :param ** bohb_kargs: arguments passed directly to the BOHB object
        """
        super(OptimizerBOHB, self).__init__(
            base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
            execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
            pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes,
            total_max_jobs=total_max_jobs)
        self._max_iteration_per_job = max_iteration_per_job
        self._min_iteration_per_job = min_iteration_per_job
        self._bohb_kwargs = bohb_kargs or {}
        self._param_iterator = None
        self._namespace = None
        self._bohb = None
        self._res = None

    def set_optimization_args(self, eta=3, min_budget=None, max_budget=None,
                              min_points_in_model=None, top_n_percent=15,
                              num_samples=None, random_fraction=1/3., bandwidth_factor=3,
                              min_bandwidth=1e-3):
        r"""
        Store the arguments forwarded to the BOHB constructor when `start` is called.
        Defaults copied from BOHB constructor, see details in BOHB.__init__

        BOHB performs robust and efficient hyperparameter optimization
        at scale by combining the speed of Hyperband searches with the
        guidance and guarantees of convergence of Bayesian
        Optimization. Instead of sampling new configurations at random,
        BOHB uses kernel density estimators to select promising candidates.

        .. highlight:: none

        For reference: ::

            @InProceedings{falkner-icml-18,
              title =        {{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},
              author =       {Falkner, Stefan and Klein, Aaron and Hutter, Frank},
              booktitle =    {Proceedings of the 35th International Conference on Machine Learning},
              pages =        {1436--1445},
              year =         {2018},
            }

        Parameters
        ----------
        eta : float (3)
            In each iteration, a complete run of sequential halving is executed. In it,
            after evaluating each configuration on the same subset size, only a fraction of
            1/eta of them 'advances' to the next round.
            Must be greater or equal to 2.
        min_budget : float (None)
            The smallest budget to consider. Needs to be positive!
            When not set, `start` uses min_iteration_per_job / max_iteration_per_job.
        max_budget : float (None)
            The largest budget to consider. Needs to be larger than min_budget!
            The budgets will be geometrically distributed
            :math:`\sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
            When not set, it is not passed to BOHB at all.
        min_points_in_model: int (None)
            number of observations to start building a KDE. Default 'None' means
            dim+1, the bare minimum.
        top_n_percent: int (15)
            percentage ( between 1 and 99, default 15) of the observations that are considered good.
        num_samples: int (None)
            number of samples to optimize EI.
            When not set, `start` uses total_max_jobs.
        random_fraction: float (1/3.)
            fraction of purely random configurations that are sampled from the
            prior without the model.
        bandwidth_factor: float (3.)
            to encourage diversity, the points proposed to optimize EI, are sampled
            from a 'widened' KDE where the bandwidth is multiplied by this factor (default: 3)
        min_bandwidth: float (1e-3)
            to keep diversity, even when all (good) samples have the same value for one of the parameters,
            a minimum bandwidth (Default: 1e-3) is used instead of zero.
        """
        # optional values are stored only when given (falsy values are ignored)
        if min_budget:
            self._bohb_kwargs['min_budget'] = min_budget
        if max_budget:
            self._bohb_kwargs['max_budget'] = max_budget
        if num_samples:
            self._bohb_kwargs['num_samples'] = num_samples
        self._bohb_kwargs['eta'] = eta
        self._bohb_kwargs['min_points_in_model'] = min_points_in_model
        self._bohb_kwargs['top_n_percent'] = top_n_percent
        self._bohb_kwargs['random_fraction'] = random_fraction
        self._bohb_kwargs['bandwidth_factor'] = bandwidth_factor
        self._bohb_kwargs['min_bandwidth'] = min_bandwidth

    def start(self):
        """
        Run the full optimization: start a local nameserver, spawn one background worker
        per concurrent slot, run BOHB, then call `stop`.
        """
        # Step 1: Start a nameserver
        fake_run_id = 'OptimizerBOHB_{}'.format(time())
        self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=None)
        self._namespace.start()

        # we have to scale the budget to the iterations per job, otherwise numbers might be too high
        budget_iteration_scale = self._max_iteration_per_job

        # Step 2: Start the workers
        workers = []
        for i in range(self._num_concurrent_workers):
            w = TrainsBandsterWorker(optimizer=self,
                                     sleep_interval=int(self.pool_period_minutes*60),
                                     budget_iteration_scale=budget_iteration_scale,
                                     base_task_id=self._base_task_id,
                                     objective=self._objective_metric,
                                     queue_name=self._execution_queue,
                                     nameserver='127.0.0.1', run_id=fake_run_id, id=i)
            w.run(background=True)
            workers.append(w)

        # Step 3: Run an optimizer
        # Build a single keyword dict: values stored by the constructor / set_optimization_args()
        # override our defaults. Passing both explicitly and via ** raised
        # "got multiple values for keyword argument" for num_samples / min_budget.
        bohb_kwargs = {
            'num_samples': self.total_max_jobs,
            'min_budget': float(self._min_iteration_per_job)/float(self._max_iteration_per_job),
        }
        bohb_kwargs.update(self._bohb_kwargs)
        self._bohb = BOHB(configspace=self.convert_hyper_parameters_to_cs(),
                          run_id=fake_run_id,
                          **bohb_kwargs)
        self._res = self._bohb.run(n_iterations=self.total_max_jobs, min_n_workers=self._num_concurrent_workers)

        # Step 4: if we get here, Shutdown
        self.stop()

    def stop(self):
        """
        Shut down the BOHB master and the nameserver (if they were started) and,
        when a result object is available, print a short summary of the run.
        """
        # After the optimizer run, we must shutdown the master and the nameserver.
        # Both are None until start() created them, so stop() before start() is a no-op.
        if self._bohb is not None:
            self._bohb.shutdown(shutdown_workers=True)
        if self._namespace is not None:
            self._namespace.shutdown()

        if not self._res:
            return

        # Step 5: Analysis
        id2config = self._res.get_id2config_mapping()
        incumbent = self._res.get_incumbent_id()
        all_runs = self._res.get_all_runs()

        # Step 6: Print Analysis
        # the incumbent may be missing from the mapping, skip the line instead of raising KeyError
        if incumbent in id2config:
            print('Best found configuration:', id2config[incumbent]['config'])
        print('A total of {} unique configurations where sampled.'.format(len(id2config.keys())))
        print('A total of {} runs where executed.'.format(len(all_runs)))
        print('Total budget corresponds to {:.1f} full function evaluations.'.format(
            sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
        if all_runs:
            print('The run took {:.1f} seconds to complete.'.format(
                all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))

    def convert_hyper_parameters_to_cs(self):
        """
        Translate the hyper-parameter definitions into a ConfigSpace.ConfigurationSpace.

        :return: ConfigurationSpace seeded with the global random seed
        :raises ValueError: if a hyper-parameter is not a Uniform / UniformInteger / Discrete range
        """
        cs = CS.ConfigurationSpace(seed=self._seed)
        for p in self._hyper_parameters:
            if isinstance(p, UniformParameterRange):
                hp = CSH.UniformFloatHyperparameter(
                    p.name, lower=p.min_value, upper=p.max_value, log=False, q=p.step_size)
            elif isinstance(p, UniformIntegerParameterRange):
                hp = CSH.UniformIntegerHyperparameter(
                    p.name, lower=p.min_value, upper=p.max_value, log=False, q=p.step_size)
            elif isinstance(p, DiscreteParameterRange):
                hp = CSH.CategoricalHyperparameter(p.name, choices=p.values)
            else:
                raise ValueError("HyperParameter type {} not supported yet with OptimizerBOHB".format(type(p)))
            cs.add_hyperparameter(hp)

        return cs
|
||||
279
trains/automation/job.py
Normal file
279
trains/automation/job.py
Normal file
@ -0,0 +1,279 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from logging import getLogger
|
||||
from time import time, sleep
|
||||
|
||||
from ..task import Task
|
||||
from ..backend_api.services import tasks as tasks_service
|
||||
from ..backend_api.services import events as events_service
|
||||
|
||||
logger = getLogger('trains.automation.job')
|
||||
|
||||
|
||||
class TrainsJob(object):
    """
    A cloned Task with overridden parameters that can be enqueued, monitored and aborted.
    """

    def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, parent=None, **kwargs):
        """
        Create a new Task based on a base_task_id with a different set of parameters

        :param str base_task_id: base task id to clone from
        :param dict parameter_override: dictionary of parameters and values to set for the cloned task
        :param dict task_overrides: Task object specific overrides
        :param list tags: additional tags to add to the newly cloned task
        :param str parent: Set newly created Task parent task field, default: base_task_id.
        :param dict kwargs: additional Task creation parameters
        """
        self.task = Task.clone(base_task_id, parent=parent or base_task_id, **kwargs)
        if tags:
            # union of the cloned task tags and the requested tags
            self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
        if parameter_override:
            params = self.task.get_parameters_as_dict()
            params.update(parameter_override)
            self.task.set_parameters_as_dict(params)
        if task_overrides:
            # todo: make sure it works
            # noinspection PyProtectedMember
            self.task._edit(task_overrides)
        self.task_started = False
        self._worker = None

    def get_metric(self, title, series):
        """
        Retrieve a specific scalar metric from the running Task.

        :param str title: Graph title (metric)
        :param str series: Series on the specific graph (variant)
        :return tuple: min value, max value, last value
        """
        # last_metrics entries are looked up by the md5 hex digest of the title / series
        title = hashlib.md5(str(title).encode('utf-8')).hexdigest()
        series = hashlib.md5(str(series).encode('utf-8')).hexdigest()
        metric = 'last_metrics.{}.{}.'.format(title, series)
        values = ['min_value', 'max_value', 'value']
        metrics = [metric + v for v in values]

        res = self.task.send(
            tasks_service.GetAllRequest(
                id=[self.task.id],
                page=0,
                page_size=1,
                only_fields=['id', ] + metrics
            )
        )
        response = res.wait()

        return tuple(response.response_data['tasks'][0]['last_metrics'][title][series][v] for v in values)

    def launch(self, queue_name=None):
        """
        Send Job for execution on the requested execution queue.
        Failures are logged as warnings and not raised.

        :param str queue_name: name of the queue to enqueue the Task on
        """
        try:
            Task.enqueue(task=self.task, queue_name=queue_name)
        except Exception as ex:
            logger.warning(ex)

    def abort(self):
        """
        Abort currently running job (can be called multiple times).
        Failures are logged as warnings and not raised.
        """
        try:
            self.task.stopped()
        except Exception as ex:
            logger.warning(ex)

    def elapsed(self):
        """
        Return the time in seconds since job started. Return -1 if job is still pending

        :return float: seconds from start
        """
        if not self.task_started and str(self.task.status) != Task.TaskStatusEnum.in_progress:
            return -1
        self.task_started = True
        started = self.task.data.started
        # timedelta has no timestamp(); total_seconds() is the elapsed time.
        # Use the same tzinfo as `started` (None keeps the naive local clock) so naive and
        # aware datetimes are never mixed. NOTE(review): assumes `started` is a datetime.
        return (datetime.now(tz=started.tzinfo) - started).total_seconds()

    def iterations(self):
        """
        Return the last iteration value of the current job. -1 if job has not started yet

        :return int: Task last iteration
        """
        if not self.task_started and self.task.status != Task.TaskStatusEnum.in_progress:
            return -1
        self.task_started = True
        return self.task.get_last_iteration()

    def task_id(self):
        """
        Return the Task id.

        :return str: Task id
        """
        return self.task.id

    def status(self):
        """
        Return the Job Task current status, see Task.TaskStatusEnum

        :return str: Task status Task.TaskStatusEnum in string
        """
        return self.task.status

    def wait(self, timeout=None, pool_period=30.):
        """
        Wait until the task is fully executed (i.e. aborted/completed/failed)

        :param timeout: maximum time (minutes) to wait for Task to finish, None waits indefinitely
        :param pool_period: check task status every pool_period seconds
        :return bool: Return True if Task finished.
        """
        tic = time()
        while timeout is None or time()-tic < timeout*60.:
            if self.is_stopped():
                return True
            sleep(pool_period)

        # timed out, report the status one last time
        return self.is_stopped()

    def get_console_output(self, number_of_reports=1):
        """
        Return a list of console outputs reported by the Task.
        Returned console outputs are retrieved from the most updated console outputs.

        :param int number_of_reports: number of reports to return, default 1, the last (most updated) console output
        :return list: List of strings each entry corresponds to one report.
        """
        return self.task.get_reported_console_output(number_of_reports=number_of_reports)

    def worker(self):
        """
        Return the current worker id executing this Job. If job is pending, returns None

        :return str: Worker ID (str) executing / executed the job, or None if job is still pending.
        """
        if self.is_pending():
            return self._worker

        if self._worker is None:
            # the last console outputs will update the worker
            # NOTE(review): nothing in this class assigns self._worker from the console output - confirm
            self.get_console_output(number_of_reports=1)
            # if we still do not have it, store empty string
            if not self._worker:
                self._worker = ''

        return self._worker

    def is_running(self):
        """
        Return True if job is currently running (pending is considered False)

        :return bool: True iff the task is currently in progress
        """
        return self.task.status == Task.TaskStatusEnum.in_progress

    def is_stopped(self):
        """
        Return True if job has executed and is not running any more

        :return bool: True if the task is in one of these states: stopped / completed / failed / published
        """
        return self.task.status in (
            Task.TaskStatusEnum.stopped, Task.TaskStatusEnum.completed,
            Task.TaskStatusEnum.failed, Task.TaskStatusEnum.published)

    def is_pending(self):
        """
        Return True if job is waiting for execution

        :return bool: True if the task is currently queued or created
        """
        return self.task.status in (Task.TaskStatusEnum.queued, Task.TaskStatusEnum.created)
|
||||
|
||||
|
||||
class JobStub(object):
    """
    In-memory stand-in exposing the TrainsJob interface without creating any Task.
    `task_started` encodes the state: None - pending, positive timestamp - running, -1 - aborted.
    """

    def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, **_):
        # only record the arguments, no Task is cloned
        self.base_task_id = base_task_id
        self.parameter_override = parameter_override
        self.task_overrides = task_overrides
        self.tags = tags
        self.task = None
        self.task_started = None
        self.iteration = -1

    def launch(self, queue_name=None):
        """
        Mark the stub as started and print what would have been launched.

        :param str queue_name: name of the queue, only printed
        """
        self.iteration = 0
        self.task_started = time()
        print('launching', self.parameter_override, 'in', queue_name)

    def abort(self):
        """
        Mark the stub as stopped.
        """
        self.task_started = -1

    def elapsed(self):
        """
        Seconds passed since `launch` was called, -1 while the stub is still pending.

        :return float: seconds from start
        """
        return -1 if self.task_started is None else time() - self.task_started

    def iterations(self):
        """
        Stored iteration counter, -1 while the stub is still pending.

        :return int: last iteration
        """
        return -1 if self.task_started is None else self.iteration

    def get_metric(self, title, series):
        """
        Return a constant metric triplet regardless of the requested title / series.

        :param str title: Graph title (metric), ignored
        :param str series: Series on the specific graph (variant), ignored
        :return tuple: min value, max value, last value
        """
        return 0, 1.0, 0.123

    def task_id(self):
        """
        :return str: constant stub id
        """
        return 'stub'

    def worker(self):
        """
        :return: always None, the stub is never executed by a worker
        """
        return None

    def status(self):
        """
        :return str: constant status string
        """
        return 'in_progress'

    def wait(self, timeout=None, pool_period=30.):
        """
        Return immediately, the stub has nothing to wait for.

        :param timeout: ignored
        :param pool_period: ignored
        :return bool: always True
        """
        return True

    def get_console_output(self, number_of_reports=1):
        """
        The stub never reports console output.

        :param int number_of_reports: ignored
        :return list: always an empty list
        """
        return []

    def is_running(self):
        """
        :return bool: True once launched and not yet aborted
        """
        started = self.task_started
        if started is None:
            return False
        return started > 0

    def is_stopped(self):
        """
        :return bool: True after `abort` was called
        """
        started = self.task_started
        if started is None:
            return False
        return started < 0

    def is_pending(self):
        """
        :return bool: True until `launch` or `abort` is called
        """
        return self.task_started is None
|
||||
|
||||
1040
trains/automation/optimization.py
Normal file
1040
trains/automation/optimization.py
Normal file
File diff suppressed because it is too large
Load Diff
263
trains/automation/parameters.py
Normal file
263
trains/automation/parameters.py
Normal file
@ -0,0 +1,263 @@
|
||||
import sys
|
||||
from itertools import product
|
||||
from random import Random as BaseRandom
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RandomSeed(object):
    """
    Holder of the process-wide seed and random generator shared by all
    hyper-parameter sampling code deriving from this class.
    """
    _seed = 1337
    _random = BaseRandom(1337)

    @staticmethod
    def set_random_seed(seed=1337):
        """
        Replace the shared generator with a new one initialized from `seed`.

        :param int seed: random seed
        """
        RandomSeed._random = BaseRandom(seed)
        RandomSeed._seed = seed

    @staticmethod
    def get_random_seed():
        """
        Return the seed the shared generator was last initialized with.

        :return int: random seed
        """
        return RandomSeed._seed
|
||||
|
||||
|
||||
class Parameter(RandomSeed):
    """
    Base Hyper-Parameter optimization object
    """

    def __init__(self, name):
        """
        :param str name: parameter name
        """
        self.name = name

    def get_value(self):
        """
        Return a dict with the Parameter name and a sampled value for the Parameter.
        The base implementation returns None, subclasses override it.

        :return dict: example, {'answer': 0.42}
        """
        pass

    def to_list(self):
        """
        Return a list of all the valid values of the Parameter.
        The base implementation returns None, subclasses override it.

        :return list: list of dicts {name: values}
        """
        pass

    def to_dict(self):
        """
        Return a dict representation of the parameter object: the class name under
        '__class__' plus every instance attribute (attributes having a to_dict()
        method are serialized with it).

        :return: dict
        """
        # use the class name directly instead of slicing it out of str(cls)
        serialize = {'__class__': type(self).__name__}
        serialize.update(dict(((k, v.to_dict() if hasattr(v, 'to_dict') else v) for k, v in self.__dict__.items())))
        return serialize

    @classmethod
    def from_dict(cls, a_dict):
        """
        Construct parameter object from a dict representation.
        The class named by '__class__' is looked up in this module and instantiated
        without calling its __init__; attributes are restored from the remaining items.

        :param dict a_dict: representation as created by to_dict()
        :return: Parameter, or None if '__class__' is missing or does not name a class of this module
        """
        a_dict = a_dict.copy()
        a_cls = a_dict.pop('__class__', None)
        if not a_cls:
            return None
        a_cls = getattr(sys.modules[__name__], a_cls, None)
        # reject names resolving to something that is not a class (e.g. an imported module),
        # calling __new__ on those would raise TypeError
        if not isinstance(a_cls, type):
            return None
        instance = a_cls.__new__(a_cls)
        instance.__dict__ = dict((k, cls.from_dict(v) if isinstance(v, dict) and '__class__' in v else v)
                                 for k, v in a_dict.items())
        return instance
|
||||
|
||||
|
||||
class UniformParameterRange(Parameter):
    """
    Uniform randomly sampled Hyper-Parameter object
    """

    def __init__(self, name, min_value, max_value, step_size=None, include_max_value=True):
        """
        Create a parameter to be sampled by the SearchStrategy

        :param str name: parameter name, should match task hyper-parameter name
        :param float min_value: minimum sample to use for uniform random sampling
        :param float max_value: maximum sample to use for uniform random sampling
        :param float step_size: If not None, set step size (quantization) for value sampling
        :param bool include_max_value: if True range includes the max_value (default True)
        """
        super(UniformParameterRange, self).__init__(name=name)
        self.min_value = float(min_value)
        self.max_value = float(max_value)
        self.step_size = float(step_size) if step_size is not None else None
        self.include_max = include_max_value

    def get_value(self):
        """
        Return uniformly sampled value based on object sampling definitions.

        Without a step size the value is drawn with random.uniform(min_value, max_value).
        With a step size the value is min_value + k * step_size, never exceeding max_value;
        max_value itself can be returned only if it lies on the grid and include_max is True.

        :return dict: {self.name: random value}
        """
        if not self.step_size:
            return {self.name: self._random.uniform(self.min_value, self.max_value)}
        steps = (self.max_value - self.min_value) / self.step_size
        # number of whole steps fitting in the range (tolerance absorbs float division noise)
        whole_steps = int(steps + 1e-9)
        max_on_grid = abs(steps - whole_steps) < 1e-9
        # index `whole_steps` is the last grid point not above max_value; it is dropped only
        # when it is exactly max_value and the max is excluded from the range
        count = whole_steps if (max_on_grid and not self.include_max) else whole_steps + 1
        return {self.name: self.min_value + (self._random.randrange(count) * self.step_size)}

    def to_list(self):
        """
        Return a list of all the valid values of the Parameter.
        If self.step_size is not defined, the range is split into 100 equal steps.
        When include_max is set, max_value is appended if it is not already the last value.

        :return list: list of dicts {name: float}
        """
        step = self.step_size or (self.max_value - self.min_value) / 100.
        if step:
            values = np.arange(start=self.min_value, stop=self.max_value,
                               step=step, dtype=type(self.min_value)).tolist()
        else:
            # min_value == max_value without a step size: np.arange cannot work with a zero step
            values = []
        if self.include_max and (not values or values[-1] < self.max_value):
            values.append(self.max_value)
        return [{self.name: v} for v in values]
|
||||
|
||||
|
||||
class UniformIntegerParameterRange(Parameter):
    """
    Uniform randomly sampled integer Hyper-Parameter object
    """

    def __init__(self, name, min_value, max_value, step_size=1, include_max_value=True):
        """
        Create a parameter to be sampled by the SearchStrategy

        :param str name: parameter name, should match task hyper-parameter name
        :param int min_value: minimum sample to use for uniform random sampling
        :param int max_value: maximum sample to use for uniform random sampling
        :param int step_size: Default step size is 1 (None is treated as 1)
        :param bool include_max_value: if True range includes the max_value (default True)
        """
        super(UniformIntegerParameterRange, self).__init__(name=name)
        self.min_value = int(min_value)
        self.max_value = int(max_value)
        # a None step cannot be used by randrange() / range(), fall back to the default step
        self.step_size = int(step_size) if step_size is not None else 1
        self.include_max = include_max_value

    def get_value(self):
        """
        Return uniformly sampled value based on object sampling definitions.
        Values are min_value + k * step_size and never exceed max_value; max_value itself
        can be returned only if it lies on the grid and include_max is True.

        :return dict: {self.name: random integer}
        """
        # stop is exclusive: max_value + 1 admits max_value when it is on the grid,
        # without admitting any grid point above it (max_value + step_size did)
        stop = self.max_value + 1 if self.include_max else self.max_value
        return {self.name: self._random.randrange(self.min_value, stop, self.step_size)}

    def to_list(self):
        """
        Return a list of all the valid values of the Parameter.
        When include_max is set, max_value is appended if it is not already the last value.

        :return list: list of dicts {name: int}
        """
        values = list(range(self.min_value, self.max_value, self.step_size))
        if self.include_max and (not values or values[-1] < self.max_value):
            values.append(self.max_value)
        return [{self.name: v} for v in values]
|
||||
|
||||
|
||||
class DiscreteParameterRange(Parameter):
    """
    Hyper-Parameter taking one value out of an explicit collection of options
    """

    def __init__(self, name, values=()):
        """
        Define the options the parameter value is uniformly picked from

        :param str name: parameter name, should match task hyper-parameter name
        :param list values: list/tuple of valid parameter values to sample from
        """
        super(DiscreteParameterRange, self).__init__(name=name)
        self.values = values

    def get_value(self):
        """
        Pick one of the valid values using the shared random generator

        :return dict: {self.name: random entry from self.values}
        """
        picked = self._random.choice(self.values)
        return {self.name: picked}

    def to_list(self):
        """
        Enumerate every valid value of the Parameter

        :return list: list of dicts {name: value}
        """
        options = []
        for option in self.values:
            options.append({self.name: option})
        return options
|
||||
|
||||
|
||||
class ParameterSet(Parameter):
    """
    Hyper-Parameter object holding explicit combinations of parameters,
    one combination is picked when sampling
    """

    def __init__(self, parameter_combinations=()):
        """
        Uniformly sample values from a list of discrete options (combinations) of parameters

        :param list parameter_combinations: list/tuple of valid parameter combinations,
            Example: two combinations with three specific parameters per combination:
                [ {'opt1': 10, 'arg2': 20, 'arg3': 30},
                  {'opt1': 11, 'arg2': 22, 'arg3': 33}, ]
            Example: two complex combinations, each one sampled from a different range
                [ {'opt1': UniformParameterRange('arg1',0,1) , 'arg2': 20},
                  {'opt2': UniformParameterRange('arg1',11,12), 'arg2': 22},]
        """
        super(ParameterSet, self).__init__(name=None)
        self.values = parameter_combinations

    def get_value(self):
        """
        Pick one combination at random and resolve it into concrete values

        :return dict: parameter name to value mapping of the picked combination
        """
        picked = self._random.choice(self.values)
        return self._get_value(picked)

    def to_list(self):
        """
        Expand every combination into all the concrete value dicts it can produce

        :return list: list of dicts {name: value}
        """
        expanded = []
        for combination in self.values:
            # per entry: the list of {name: value} options it can take
            per_entry_options = [
                entry.to_list() if isinstance(entry, Parameter) else [{key: entry}, ]
                for key, entry in combination.items()
            ]
            # cartesian product over the entries, merged into a single dict per state
            for state in product(*per_entry_options):
                merged = {}
                for option in state:
                    merged.update(option)
                expanded.append(merged)

        return expanded

    @staticmethod
    def _get_value(combination):
        """
        Resolve a single combination: Parameter entries contribute their sampled
        {name: value} dict, any other entry is copied under its own key.
        """
        resolved = {}
        for key, entry in combination.items():
            if not isinstance(entry, Parameter):
                resolved[key] = entry
                continue
            resolved.update(entry.get_value())

        return resolved
|
||||
@ -161,6 +161,8 @@ class Session(TokenManager):
|
||||
api_version = token_dict.get('api_version')
|
||||
if not api_version:
|
||||
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
|
||||
if token_dict.get('server_version'):
|
||||
Session._client.append(('trains-server', token_dict.get('server_version'), ))
|
||||
|
||||
Session.api_version = str(api_version)
|
||||
except (jwt.DecodeError, ValueError):
|
||||
|
||||
@ -11,6 +11,10 @@ class EnvEntry(Entry):
|
||||
conversions[bool] = text_to_bool
|
||||
return conversions
|
||||
|
||||
def __init__(self, key, *more_keys, **kwargs):
    """
    Create an entry and remember whether its errors should be reported.

    :param key: first key of the entry
    :param more_keys: additional keys, forwarded to the base class together with ``key``
    :param kwargs: forwarded unchanged to the base class initializer. The optional
        ``ignore_errors`` flag (default False) is also stored on the instance as
        ``_ignore_errors``, which error() checks.
    """
    super(EnvEntry, self).__init__(key, *more_keys, **kwargs)
    # kwargs is this call's own dict and is not used again, so a plain lookup is enough
    self._ignore_errors = kwargs.get('ignore_errors', False)
|
||||
|
||||
def _get(self, key):
    """
    Read a single environment variable.

    :param str key: name of the environment variable
    :return: the variable's value with surrounding whitespace stripped, or ``NotSet``
        when the variable is undefined or blank
    """
    stripped = getenv(key, "").strip()
    if stripped:
        return stripped
    return NotSet
|
||||
@ -22,7 +26,8 @@ class EnvEntry(Entry):
|
||||
return "env:{}".format(super(EnvEntry, self).__str__())
|
||||
|
||||
def error(self, message):
    """
    Report a configuration problem for this environment entry.

    The message is printed to stdout once, prefixed with ``Environment configuration: ``,
    unless the entry was created with ``ignore_errors=True``.

    :param message: text describing the problem
    """
    # Honour the flag before printing anything: an unconditional print here would make
    # ``ignore_errors`` ineffective and would emit the message twice otherwise.
    if self._ignore_errors:
        return
    print("Environment configuration: {}".format(message))
|
||||
|
||||
def exists(self):
    """
    Check whether any of this entry's keys is defined in the process environment.

    :return bool: True as soon as one non-empty key in ``self.keys`` has an environment
        value (an empty-string value still counts as defined), False otherwise
    """
    for env_name in self.keys:
        # a key counts only if it is defined in the environment and is itself a truthy name
        if getenv(env_name) is not None and env_name:
            return True
    return False
|
||||
|
||||
@ -1375,7 +1375,11 @@ class OutputModel(BaseModel):
|
||||
self._floating_data = None
|
||||
|
||||
# now we have to update the creator task so it points to us
|
||||
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
|
||||
if self._task.status not in (self._task.TaskStatusEnum.created, self._task.TaskStatusEnum.in_progress):
|
||||
self._log.warning('Could not update last created model in Task {}, '
|
||||
'Task status \'{}\' cannot be updated'.format(self._task.id, self._task.status))
|
||||
else:
|
||||
self._base_model.update_for_task(task_id=self._task.id, override_model_id=self.id)
|
||||
|
||||
return self._base_model
|
||||
|
||||
|
||||
@ -27,11 +27,17 @@ class CacheManager(object):
|
||||
def get_local_copy(self, remote_url):
|
||||
helper = StorageHelper.get(remote_url)
|
||||
if not helper:
|
||||
raise ValueError("Remote storage not supported: {}".format(remote_url))
|
||||
raise ValueError("Storage access failed: {}".format(remote_url))
|
||||
# check if we need to cache the file
|
||||
direct_access = helper._driver.get_direct_access(remote_url)
|
||||
try:
|
||||
direct_access = helper._driver.get_direct_access(remote_url)
|
||||
except (OSError, ValueError):
|
||||
LoggerRoot.get_base_logger().warning("Failed accessing local file: {}".format(remote_url))
|
||||
return None
|
||||
|
||||
if direct_access:
|
||||
return direct_access
|
||||
|
||||
# check if we already have the file in our cache
|
||||
cached_file, cached_size = self._get_cache_file(remote_url)
|
||||
if cached_size is not None:
|
||||
|
||||
@ -263,7 +263,8 @@ class StorageHelper(object):
|
||||
log.error(str(ex))
|
||||
return None
|
||||
except Exception as ex:
|
||||
log.error("Failed credentials for {}. Reason: {}".format(base_url or url, ex))
|
||||
log.error("Failed creating storage object {} Reason: {}".format(
|
||||
base_url or url, ex))
|
||||
return None
|
||||
|
||||
cls._helpers[instance_key] = instance
|
||||
@ -361,14 +362,16 @@ class StorageHelper(object):
|
||||
# url2pathname is specifically intended to operate on (urlparse result).path
|
||||
# and returns a cross-platform compatible result
|
||||
driver_uri = url2pathname(url)
|
||||
if Path(driver_uri).is_file():
|
||||
driver_uri = str(Path(driver_uri).parent)
|
||||
elif not Path(driver_uri).exists():
|
||||
# assume a folder and create
|
||||
Path(driver_uri).mkdir(parents=True, exist_ok=True)
|
||||
path_driver_uri = Path(driver_uri)
|
||||
# if path_driver_uri.is_file():
|
||||
# driver_uri = str(path_driver_uri.parent)
|
||||
# elif not path_driver_uri.exists():
|
||||
# # assume a folder and create
|
||||
# # Path(driver_uri).mkdir(parents=True, exist_ok=True)
|
||||
# pass
|
||||
|
||||
self._driver = _FileStorageDriver(driver_uri)
|
||||
self._container = self._driver.get_container(container_name='.')
|
||||
self._driver = _FileStorageDriver(str(path_driver_uri.root))
|
||||
self._container = None
|
||||
|
||||
@classmethod
|
||||
def terminate_uploads(cls, force=True, timeout=2.0):
|
||||
@ -910,8 +913,8 @@ class StorageHelper(object):
|
||||
except Exception as e:
|
||||
self._log.error("Calling upload callback when starting upload: %s" % str(e))
|
||||
if verbose:
|
||||
msg = 'Starting upload: {} => {}{}'.format(src_path, self._container.name if self._container else '',
|
||||
object_name)
|
||||
msg = 'Starting upload: {} => {}{}'.format(
|
||||
src_path, self._container.name if self._container else '', object_name)
|
||||
if object_name.startswith('file://') or object_name.startswith('/'):
|
||||
self._log.debug(msg)
|
||||
else:
|
||||
@ -946,7 +949,8 @@ class StorageHelper(object):
|
||||
def _get_object(self, path):
|
||||
object_name = self._normalize_object_name(path)
|
||||
try:
|
||||
return self._driver.get_object(container_name=self._container.name, object_name=object_name)
|
||||
return self._driver.get_object(
|
||||
container_name=self._container.name if self._container else '', object_name=object_name)
|
||||
except ConnectionError as ex:
|
||||
raise DownloadError
|
||||
except Exception as e:
|
||||
@ -1757,9 +1761,6 @@ class _FileStorageDriver(_Driver):
|
||||
# Use the key as the path to the storage
|
||||
self.base_path = key
|
||||
|
||||
if not os.path.isdir(self.base_path):
|
||||
raise ValueError('The base path is not a directory')
|
||||
|
||||
def _make_path(self, path, ignore_existing=True):
|
||||
"""
|
||||
Create a path by checking if it already exists
|
||||
@ -1781,7 +1782,7 @@ class _FileStorageDriver(_Driver):
|
||||
"""
|
||||
|
||||
if '/' in container_name or '\\' in container_name:
|
||||
raise ValueError(value=None, driver=self, container_name=container_name)
|
||||
raise ValueError("Container name \"{}\" cannot contain \\ or / ".format(container_name))
|
||||
|
||||
def _make_container(self, container_name):
|
||||
"""
|
||||
@ -1793,7 +1794,7 @@ class _FileStorageDriver(_Driver):
|
||||
:return: Container instance.
|
||||
:rtype: :class:`Container`
|
||||
"""
|
||||
|
||||
container_name = container_name or '.'
|
||||
self._check_container_name(container_name)
|
||||
|
||||
full_path = os.path.realpath(os.path.join(self.base_path, container_name))
|
||||
@ -1801,14 +1802,15 @@ class _FileStorageDriver(_Driver):
|
||||
try:
|
||||
stat = os.stat(full_path)
|
||||
if not os.path.isdir(full_path):
|
||||
raise OSError('Target path is not a directory')
|
||||
raise OSError("Target path \"{}\" is not a directory".format(full_path))
|
||||
except OSError:
|
||||
raise ValueError(value=None, driver=self, container_name=container_name)
|
||||
raise OSError("Target path \"{}\" is not accessible or does not exist".format(full_path))
|
||||
|
||||
extra = {}
|
||||
extra['creation_time'] = stat.st_ctime
|
||||
extra['access_time'] = stat.st_atime
|
||||
extra['modify_time'] = stat.st_mtime
|
||||
extra = {
|
||||
'creation_time': stat.st_ctime,
|
||||
'access_time': stat.st_atime,
|
||||
'modify_time': stat.st_mtime,
|
||||
}
|
||||
|
||||
return self._Container(name=container_name, extra=extra, driver=self)
|
||||
|
||||
@ -1826,20 +1828,21 @@ class _FileStorageDriver(_Driver):
|
||||
:rtype: :class:`Object`
|
||||
"""
|
||||
|
||||
full_path = os.path.realpath(os.path.join(self.base_path, container.name, object_name))
|
||||
full_path = os.path.realpath(os.path.join(self.base_path, container.name if container else '.', object_name))
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
raise ValueError(value=None, driver=self, object_name=object_name)
|
||||
raise ValueError("Target path \"{}\" already exist".format(full_path))
|
||||
|
||||
try:
|
||||
stat = os.stat(full_path)
|
||||
except Exception:
|
||||
raise ValueError(value=None, driver=self, object_name=object_name)
|
||||
raise ValueError("Cannot access target path \"{}\"".format(full_path))
|
||||
|
||||
extra = {}
|
||||
extra['creation_time'] = stat.st_ctime
|
||||
extra['access_time'] = stat.st_atime
|
||||
extra['modify_time'] = stat.st_mtime
|
||||
extra = {
|
||||
'creation_time': stat.st_ctime,
|
||||
'access_time': stat.st_atime,
|
||||
'modify_time': stat.st_mtime,
|
||||
}
|
||||
|
||||
return self.Object(name=object_name, size=stat.st_size, extra=extra,
|
||||
driver=self, container=container, hash=None, meta_data=None)
|
||||
@ -1914,10 +1917,10 @@ class _FileStorageDriver(_Driver):
|
||||
:return: A CDN URL for this container.
|
||||
:rtype: ``str``
|
||||
"""
|
||||
path = os.path.realpath(os.path.join(self.base_path, container.name))
|
||||
path = os.path.realpath(os.path.join(self.base_path, container.name if container else '.'))
|
||||
|
||||
if check and not os.path.isdir(path):
|
||||
raise ValueError(value=None, driver=self, container_name=container.name)
|
||||
raise ValueError("Target path \"{}\" does not exist".format(path))
|
||||
|
||||
return path
|
||||
|
||||
@ -1977,9 +1980,7 @@ class _FileStorageDriver(_Driver):
|
||||
base_name = os.path.basename(destination_path)
|
||||
|
||||
if not base_name and not os.path.exists(destination_path):
|
||||
raise ValueError(
|
||||
value='Path %s does not exist' % (destination_path),
|
||||
driver=self)
|
||||
raise ValueError('Path \"%s\" does not exist'.format(destination_path))
|
||||
|
||||
if not base_name:
|
||||
file_path = os.path.join(destination_path, obj.name)
|
||||
@ -1987,7 +1988,7 @@ class _FileStorageDriver(_Driver):
|
||||
file_path = destination_path
|
||||
|
||||
if os.path.exists(file_path) and not overwrite_existing:
|
||||
raise ValueError('File %s already exists, but ' % (file_path) + 'overwrite_existing=False')
|
||||
raise ValueError('File \"{}\" already exists, but overwrite_existing=False'.format(file_path))
|
||||
|
||||
try:
|
||||
shutil.copy(obj_path, file_path)
|
||||
@ -2146,7 +2147,7 @@ class _FileStorageDriver(_Driver):
|
||||
:return: :class:`Container` instance on success.
|
||||
:rtype: :class:`Container`
|
||||
"""
|
||||
|
||||
container_name = container_name or '.'
|
||||
self._check_container_name(container_name)
|
||||
|
||||
path = os.path.join(self.base_path, container_name)
|
||||
@ -2156,13 +2157,13 @@ class _FileStorageDriver(_Driver):
|
||||
except OSError:
|
||||
exp = sys.exc_info()[1]
|
||||
if exp.errno == errno.EEXIST:
|
||||
raise ValueError('Container %s with this name already exists. The name '
|
||||
raise ValueError('Container \"{}\" with this name already exists. The name '
|
||||
'must be unique among all the containers in the '
|
||||
'system' % container_name)
|
||||
'system'.format(container_name))
|
||||
else:
|
||||
raise ValueError( 'Error creating container %s' % container_name)
|
||||
raise ValueError('Error creating container \"{}\"'.format(container_name))
|
||||
except Exception:
|
||||
raise ValueError('Error creating container %s' % container_name)
|
||||
raise ValueError('Error creating container \"{}\"'.format(container_name))
|
||||
|
||||
return self._make_container(container_name)
|
||||
|
||||
@ -2179,7 +2180,7 @@ class _FileStorageDriver(_Driver):
|
||||
|
||||
# Check if there are any objects inside this
|
||||
for obj in self._get_objects(container):
|
||||
raise ValueError(value='Container %s is not empty' % container.name)
|
||||
raise ValueError('Container \"%s\" is not empty'.format(container.name))
|
||||
|
||||
path = self.get_container_cdn_url(container, check=True)
|
||||
|
||||
@ -2259,7 +2260,10 @@ class _FileStorageDriver(_Driver):
|
||||
# this will always make sure we have full path and file:// prefix
|
||||
full_url = StorageHelper.conform_url(remote_path)
|
||||
# now get rid of the file:// prefix
|
||||
return Path(full_url[7:]).as_posix()
|
||||
path = Path(full_url[7:])
|
||||
if not path.exists():
|
||||
raise ValueError("Requested path does not exist: {}".format(path))
|
||||
return path.as_posix()
|
||||
|
||||
def test_upload(self, test_path, config, **kwargs):
|
||||
return True
|
||||
|
||||
@ -9,13 +9,23 @@ def get_config_object_matcher(**patterns):
|
||||
raise ValueError('Unsupported object matcher (expecting string): %s'
|
||||
% ', '.join('%s=%s' % (k, v) for k, v in unsupported.items()))
|
||||
|
||||
# optimize simple patters
|
||||
starts_with = {k: v.rstrip('*') for k, v in patterns.items() if '*' not in v.rstrip('*') and '?' not in v}
|
||||
patterns = {k: v for k, v in patterns.items() if v not in starts_with}
|
||||
|
||||
def _matcher(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if not value:
|
||||
continue
|
||||
pat = patterns.get(key)
|
||||
if pat and fnmatch.fnmatch(value, pat):
|
||||
return True
|
||||
start = starts_with.get(key)
|
||||
if start:
|
||||
if value.startswith(start):
|
||||
return True
|
||||
else:
|
||||
pat = patterns.get(key)
|
||||
if pat and fnmatch.fnmatch(value, pat):
|
||||
return True
|
||||
|
||||
return _matcher
|
||||
|
||||
|
||||
|
||||
@ -156,8 +156,6 @@ class Task(_Task):
|
||||
object.
|
||||
|
||||
:return: The current running Task (experiment).
|
||||
|
||||
:rtype: Task() object or ``None``
|
||||
"""
|
||||
return cls.__main_task
|
||||
|
||||
@ -285,8 +283,6 @@ class Task(_Task):
|
||||
- ``False`` - Do not automatically create.
|
||||
|
||||
:return: The main execution Task (Task context).
|
||||
|
||||
:rtype: Task object
|
||||
"""
|
||||
|
||||
def verify_defaults_match():
|
||||
@ -509,8 +505,6 @@ class Task(_Task):
|
||||
|
||||
:type task_type: TaskTypeEnum(value)
|
||||
:return: A new experiment.
|
||||
|
||||
:rtype: Task() object
|
||||
"""
|
||||
if not project_name:
|
||||
if not cls.__main_task:
|
||||
@ -543,8 +537,6 @@ class Task(_Task):
|
||||
:param str task_name: The name of the Task within ``project_name`` to get.
|
||||
|
||||
:return: The Task specified by Id, or project name / experiment name combination.
|
||||
|
||||
:rtype: Task object
|
||||
"""
|
||||
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
|
||||
|
||||
@ -574,8 +566,6 @@ class Task(_Task):
|
||||
If None is passed, returns all tasks within the project
|
||||
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
|
||||
:return: The Tasks specified by the parameter combinations (see the parameters).
|
||||
|
||||
:rtype: List of Task objects
|
||||
"""
|
||||
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {}))
|
||||
|
||||
@ -604,8 +594,6 @@ class Task(_Task):
|
||||
A read-only dictionary of Task artifacts (name, artifact).
|
||||
|
||||
:return: The artifacts.
|
||||
|
||||
:rtype: dict
|
||||
"""
|
||||
if not Session.check_min_api_version('2.3'):
|
||||
return ReadOnlyDict()
|
||||
@ -649,8 +637,6 @@ class Task(_Task):
|
||||
If ``None``, the new task inherits the original Task's project. (Optional)
|
||||
|
||||
:return: The new cloned Task (experiment).
|
||||
|
||||
:rtype: Task object
|
||||
"""
|
||||
assert isinstance(source_task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
@ -685,7 +671,7 @@ class Task(_Task):
|
||||
:param str queue_name: The name of the queue. If not specified, then ``queue_id`` must be specified.
|
||||
:param str queue_id: The Id of the queue. If not specified, then ``queue_name`` must be specified.
|
||||
|
||||
:return: An enqueue response.
|
||||
:return: An enqueue JSON response.
|
||||
|
||||
.. code-block:: javascript
|
||||
|
||||
@ -713,8 +699,6 @@ class Task(_Task):
|
||||
- ``last_update`` - The last Task update time, including Task creation, update, change, or events for
|
||||
this task (ISO 8601 format).
|
||||
- ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued.
|
||||
|
||||
:rtype: JSON
|
||||
"""
|
||||
assert isinstance(task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
@ -746,7 +730,7 @@ class Task(_Task):
|
||||
:param task: The Task to dequeue. Specify a Task object or Task Id.
|
||||
:type task: Task object / str
|
||||
|
||||
:return: A dequeue response.
|
||||
:return: A dequeue JSON response.
|
||||
|
||||
.. code-block:: javascript
|
||||
|
||||
@ -774,8 +758,6 @@ class Task(_Task):
|
||||
- ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued.
|
||||
|
||||
- ``updated`` - The number of Tasks updated (an integer or ``null``).
|
||||
|
||||
:rtype: JSON
|
||||
"""
|
||||
assert isinstance(task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
@ -871,10 +853,8 @@ class Task(_Task):
|
||||
|
||||
:type configuration: dict, pathlib.Path/str
|
||||
|
||||
:return: If a dictictonary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
||||
specified, then a path to a local configuration file is returned.
|
||||
|
||||
:rtype: Configuration object
|
||||
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
||||
specified, then a path to a local configuration file is returned. Configuration object.
|
||||
"""
|
||||
if not isinstance(configuration, (dict, Path, six.string_types)):
|
||||
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
|
||||
@ -936,9 +916,7 @@ class Task(_Task):
|
||||
'person': 1
|
||||
}
|
||||
|
||||
:return: The label enumeration dictionary.
|
||||
|
||||
:rtype: JSON
|
||||
:return: The label enumeration dictionary (JSON).
|
||||
"""
|
||||
if not isinstance(enumeration, dict):
|
||||
raise ValueError("connect_label_enumeration supports only `dict` type, "
|
||||
@ -960,8 +938,6 @@ class Task(_Task):
|
||||
**Trains Web-App (UI)**.
|
||||
|
||||
:return: The Logger for the Task (experiment).
|
||||
|
||||
:rtype: Logger object
|
||||
"""
|
||||
return self._get_logger()
|
||||
|
||||
@ -1106,8 +1082,6 @@ class Task(_Task):
|
||||
After calling ``get_registered_artifacts``, you can still modify the registered artifacts.
|
||||
|
||||
:return: The registered (dynamically synchronized) artifacts.
|
||||
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._artifacts_manager.registered_artifacts
|
||||
|
||||
@ -1149,8 +1123,6 @@ class Task(_Task):
|
||||
- ``True`` - Upload succeeded.
|
||||
- ``False`` - Upload failed.
|
||||
|
||||
:rtype: bool
|
||||
|
||||
:raise: If the artifact object type is not supported, raise a ``ValueError``.
|
||||
"""
|
||||
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
|
||||
@ -1183,8 +1155,6 @@ class Task(_Task):
|
||||
|
||||
- ``True`` - Is the main execution Task.
|
||||
- ``False`` - Is not the main execution Task.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self.is_main_task()
|
||||
|
||||
@ -1205,8 +1175,6 @@ class Task(_Task):
|
||||
|
||||
- ``True`` - Is the main execution Task.
|
||||
- ``False`` - Is not the main execution Task.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self is self.__main_task
|
||||
|
||||
@ -1263,8 +1231,6 @@ class Task(_Task):
|
||||
sends a request to the **Trains Server** (backend).
|
||||
|
||||
:return: The last reported iteration number.
|
||||
|
||||
:rtype: int
|
||||
"""
|
||||
self._reload_last_iteration()
|
||||
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
|
||||
@ -1320,8 +1286,6 @@ class Task(_Task):
|
||||
}
|
||||
|
||||
:return: The last scalar metrics.
|
||||
|
||||
:rtype: dict
|
||||
"""
|
||||
self.reload()
|
||||
metrics = self.data.last_metrics
|
||||
@ -1339,8 +1303,6 @@ class Task(_Task):
|
||||
|
||||
.. note::
|
||||
The values are not parsed. They are returned as is.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
return naive_nested_from_flat_dictionary(self.get_parameters())
|
||||
|
||||
|
||||
@ -27,10 +27,13 @@ class CheckPackageUpdates(object):
|
||||
cls._package_version_checked = True
|
||||
client, version = Session._client[0]
|
||||
version = Version(version)
|
||||
is_demo = 'https://demoapi.trains.allegro.ai/'.startswith(Session.get_api_server_host())
|
||||
|
||||
update_server_releases = requests.get(
|
||||
'https://updates.trains.allegro.ai/updates',
|
||||
json={"versions": {c: str(v) for c, v in Session._client}},
|
||||
json={"demo": is_demo,
|
||||
"versions": {c: str(v) for c, v in Session._client},
|
||||
"CI": str(os.environ.get('CI', ''))},
|
||||
timeout=3.0
|
||||
)
|
||||
|
||||
@ -43,6 +46,10 @@ class CheckPackageUpdates(object):
|
||||
if "version" not in client_answer:
|
||||
return None
|
||||
|
||||
# do not output upgrade message if we are running inside a CI process.
|
||||
if EnvEntry("CI", type=bool, ignore_errors=True).get():
|
||||
return None
|
||||
|
||||
latest_version = Version(client_answer["version"])
|
||||
|
||||
if version >= latest_version:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user