Fix documentation and layout (PEP8)

This commit is contained in:
allegroai
2020-05-24 08:16:12 +03:00
parent 70cdb13252
commit 96f899d028
43 changed files with 580 additions and 322 deletions

View File

@@ -1,6 +1,9 @@
from time import sleep, time
from ..parameters import DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange
from typing import Any, Optional, Sequence
from ..optimization import Objective, SearchStrategy
from ..parameters import (
DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange, Parameter, )
from ...task import Task
try:
@@ -9,16 +12,27 @@ try:
import hpbandster.core.nameserver as hpns
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
Task.add_requirements('hpbandster')
except ImportError:
raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
"install with: pip install hpbandster")
class TrainsBandsterWorker(Worker):
def __init__(self, *args, optimizer, base_task_id, queue_name, objective,
sleep_interval=0, budget_iteration_scale=1., **kwargs):
super(TrainsBandsterWorker, self).__init__(*args, **kwargs)
class _TrainsBandsterWorker(Worker):
def __init__(
self,
*args, # type: Any
optimizer, # type: OptimizerBOHB
base_task_id, # type: str
queue_name, # type: str
objective, # type: Objective
sleep_interval=0, # type: float
budget_iteration_scale=1., # type: float
**kwargs # type: Any
):
# type: (...) -> _TrainsBandsterWorker
super(_TrainsBandsterWorker, self).__init__(*args, **kwargs)
self.optimizer = optimizer
self.base_task_id = base_task_id
self.queue_name = queue_name
@@ -28,6 +42,7 @@ class TrainsBandsterWorker(Worker):
self._current_job = None
def compute(self, config, budget, **kwargs):
# type: (dict, float, Any) -> dict
"""
Simple example for a compute function
The loss is just the config + some noise (that decreases with the budget)
@@ -43,10 +58,12 @@ class TrainsBandsterWorker(Worker):
'info' (dict)
"""
self._current_job = self.optimizer.helper_create_job(self.base_task_id, parameter_override=config)
# noinspection PyProtectedMember
self.optimizer._current_jobs.append(self._current_job)
self._current_job.launch(self.queue_name)
iteration_value = None
while not self._current_job.is_stopped():
# noinspection PyProtectedMember
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(self._current_job)
if iteration_value and iteration_value[0] >= self.budget_iteration_scale * budget:
self._current_job.abort()
@@ -56,33 +73,65 @@ class TrainsBandsterWorker(Worker):
result = {
# this is the a mandatory field to run hyperband
# remember: HpBandSter always minimizes!
'loss': float(self.objective.get_normalized_objective(self._current_job)*-1.),
'loss': float(self.objective.get_normalized_objective(self._current_job) * -1.),
# can be used for any user-defined information - also mandatory
'info': self._current_job.task_id()
}
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
# noinspection PyProtectedMember
self.optimizer._current_jobs.remove(self._current_job)
return result
class OptimizerBOHB(SearchStrategy, RandomSeed):
def __init__(self, base_task_id, hyper_parameters, objective_metric,
execution_queue, num_concurrent_workers, min_iteration_per_job, max_iteration_per_job, total_max_jobs,
pool_period_min=2.0, max_job_execution_minutes=None, **bohb_kargs):
def __init__(
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric, # type: Objective
execution_queue, # type: str
num_concurrent_workers, # type: int
total_max_jobs, # type: Optional[int]
min_iteration_per_job, # type: Optional[int]
max_iteration_per_job, # type: Optional[int]
pool_period_min=2., # type: float
max_job_execution_minutes=None, # type: Optional[float]
**bohb_kwargs, # type: Any
):
# type: (...) -> OptimizerBOHB
"""
Initialize a search strategy optimizer
Initialize a BOHB search strategy optimizer
BOHB performs robust and efficient hyperparameter optimization at scale by combining
the speed of Hyperband searches with the guidance and guarantees of convergence of Bayesian
Optimization. Instead of sampling new configurations at random,
BOHB uses kernel density estimators to select promising candidates.
For reference: ::
@InProceedings{falkner-icml-18,
title = {{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},
author = {Falkner, Stefan and Klein, Aaron and Hutter, Frank},
booktitle = {Proceedings of the 35th International Conference on Machine Learning},
pages = {1436--1445},
year = {2018},
}
:param str base_task_id: Task ID (str)
:param list hyper_parameters: list of Parameter objects to optimize over
:param Objective objective_metric: Objective metric to maximize / minimize
:param str execution_queue: execution queue to use for launching Tasks (experiments).
:param int num_concurrent_workers: Limit number of concurrent running machines
:param float min_iteration_per_job: minimum number of iterations for a job to run.
:param int num_concurrent_workers: Limit number of concurrent running Tasks (machines)
:param int min_iteration_per_job: minimum number of iterations for a job to run.
'iterations' are the reported iterations for the specified objective,
not the maximum reported iteration of the Task.
:param int max_iteration_per_job: number of iteration per job
:param int total_max_jobs: total maximum job for the optimization process. Default None, unlimited
'iterations' are the reported iterations for the specified objective,
not the maximum reported iteration of the Task.
:param int total_max_jobs: total maximum job for the optimization process.
Must be provided in order to calculate the total budget for the optimization process.
:param float pool_period_min: time in minutes between two consecutive pools
:param float max_job_execution_minutes: maximum time per single job in minutes, if exceeded job is aborted
:param ** bohb_kargs: arguments passed directly yo the BOHB object
:param bohb_kwargs: arguments passed directly to the BOHB object
"""
super(OptimizerBOHB, self).__init__(
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
@@ -91,16 +140,25 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
total_max_jobs=total_max_jobs)
self._max_iteration_per_job = max_iteration_per_job
self._min_iteration_per_job = min_iteration_per_job
self._bohb_kwargs = bohb_kargs or {}
self._bohb_kwargs = bohb_kwargs or {}
self._param_iterator = None
self._namespace = None
self._bohb = None
self._res = None
def set_optimization_args(self, eta=3, min_budget=None, max_budget=None,
min_points_in_model=None, top_n_percent=15,
num_samples=None, random_fraction=1/3., bandwidth_factor=3,
min_bandwidth=1e-3):
def set_optimization_args(
self,
eta=3, # type: float
min_budget=None, # type: Optional[float]
max_budget=None, # type: Optional[float]
min_points_in_model=None, # type: Optional[int]
top_n_percent=15, # type: Optional[int]
num_samples=None, # type: Optional[int]
random_fraction=1 / 3., # type: Optional[float]
bandwidth_factor=3, # type: Optional[float]
min_bandwidth=1e-3, # type: Optional[float]
):
# type: (...) -> ()
"""
Defaults copied from BOHB constructor, see details in BOHB.__init__
@@ -134,7 +192,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
max_budget : float (1)
The largest budget to consider. Needs to be larger than min_budget!
The budgets will be geometrically distributed
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
min_points_in_model: int (None)
number of observations to start building a KDE. Default 'None' means
dim+1, the bare minimum.
@@ -166,9 +224,16 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
self._bohb_kwargs['min_bandwidth'] = min_bandwidth
def start(self):
# Step 1: Start a nameserver
# type: () -> ()
"""
Start the Optimizer controller function loop()
If the calling process is stopped, the controller will stop as well.
Notice: This function returns only after the optimization is completed or stop() was called.
"""
# Step 1: Start a NameServer
fake_run_id = 'OptimizerBOHB_{}'.format(time())
self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=None)
self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=0)
self._namespace.start()
# we have to scale the budget to the iterations per job, otherwise numbers might be too high
@@ -177,21 +242,22 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
# Step 2: Start the workers
workers = []
for i in range(self._num_concurrent_workers):
w = TrainsBandsterWorker(optimizer=self,
sleep_interval=int(self.pool_period_minutes*60),
budget_iteration_scale=budget_iteration_scale,
base_task_id=self._base_task_id,
objective=self._objective_metric,
queue_name=self._execution_queue,
nameserver='127.0.0.1', run_id=fake_run_id, id=i)
w = _TrainsBandsterWorker(
optimizer=self,
sleep_interval=int(self.pool_period_minutes * 60),
budget_iteration_scale=budget_iteration_scale,
base_task_id=self._base_task_id,
objective=self._objective_metric,
queue_name=self._execution_queue,
nameserver='127.0.0.1', run_id=fake_run_id, id=i)
w.run(background=True)
workers.append(w)
# Step 3: Run an optimizer
self._bohb = BOHB(configspace=self.convert_hyper_parameters_to_cs(),
self._bohb = BOHB(configspace=self._convert_hyper_parameters_to_cs(),
run_id=fake_run_id,
num_samples=self.total_max_jobs,
min_budget=float(self._min_iteration_per_job)/float(self._max_iteration_per_job),
min_budget=float(self._min_iteration_per_job) / float(self._max_iteration_per_job),
**self._bohb_kwargs)
self._res = self._bohb.run(n_iterations=self.total_max_jobs, min_n_workers=self._num_concurrent_workers)
@@ -199,6 +265,11 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
self.stop()
def stop(self):
# type: () -> ()
"""
Stop the current running optimization loop,
Called from a different thread than the start()
"""
# After the optimizer run, we must shutdown the master and the nameserver.
self._bohb.shutdown(shutdown_workers=True)
self._namespace.shutdown()
@@ -216,13 +287,14 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
print('A total of {} unique configurations where sampled.'.format(len(id2config.keys())))
print('A total of {} runs where executed.'.format(len(self._res.get_all_runs())))
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
print('The run took {:.1f} seconds to complete.'.format(
all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))
all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))
def convert_hyper_parameters_to_cs(self):
def _convert_hyper_parameters_to_cs(self):
# type: () -> CS.ConfigurationSpace
cs = CS.ConfigurationSpace(seed=self._seed)
for p in self._hyper_parameters:
if isinstance(p, UniformParameterRange):

View File

@@ -2,16 +2,26 @@ import hashlib
from datetime import datetime
from logging import getLogger
from time import time, sleep
from typing import Optional, Mapping, Sequence, Any
from ..task import Task
from ..backend_api.services import tasks as tasks_service
from ..backend_api.services import events as events_service
logger = getLogger('trains.automation.job')
class TrainsJob(object):
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, parent=None, **kwargs):
def __init__(
self,
base_task_id, # type: str
parameter_override=None, # type: Optional[Mapping[str, str]]
task_overrides=None, # type: Optional[Mapping[str, str]]
tags=None, # type: Optional[Sequence[str]]
parent=None, # type: Optional[str]
**kwargs # type: Any
):
# type: (...) -> TrainsJob
"""
Create a new Task based on a base_task_id with a different set of parameters
@@ -32,11 +42,12 @@ class TrainsJob(object):
if task_overrides:
# todo: make sure it works
# noinspection PyProtectedMember
self.task._edit(task_overrides)
self.task._edit(**task_overrides)
self.task_started = False
self._worker = None
def get_metric(self, title, series):
# type: (str, str) -> (float, float, float)
"""
Retrieve a specific scalar metric from the running Task.
@@ -63,6 +74,7 @@ class TrainsJob(object):
return tuple(response.response_data['tasks'][0]['last_metrics'][title][series][v] for v in values)
def launch(self, queue_name=None):
# type: (str) -> ()
"""
Send Job for execution on the requested execution queue
@@ -74,6 +86,7 @@ class TrainsJob(object):
logger.warning(ex)
def abort(self):
# type: () -> ()
"""
Abort currently running job (can be called multiple times)
"""
@@ -83,6 +96,7 @@ class TrainsJob(object):
logger.warning(ex)
def elapsed(self):
# type: () -> float
"""
Return the time in seconds since job started. Return -1 if job is still pending
@@ -94,6 +108,7 @@ class TrainsJob(object):
return (datetime.now() - self.task.data.started).timestamp()
def iterations(self):
# type: () -> int
"""
Return the last iteration value of the current job. -1 if job has not started yet
@@ -105,6 +120,7 @@ class TrainsJob(object):
return self.task.get_last_iteration()
def task_id(self):
# type: () -> str
"""
Return the Task id.
@@ -113,6 +129,7 @@ class TrainsJob(object):
return self.task.id
def status(self):
# type: () -> Task.TaskStatusEnum
"""
Return the Job Task current status, see Task.TaskStatusEnum
@@ -121,6 +138,7 @@ class TrainsJob(object):
return self.task.status
def wait(self, timeout=None, pool_period=30.):
# type: (Optional[float], float) -> bool
"""
Wait until the task is fully executed (i.e. aborted/completed/failed)
@@ -129,7 +147,7 @@ class TrainsJob(object):
:return bool: Return True if the Task finished.
"""
tic = time()
while timeout is None or time()-tic < timeout*60.:
while timeout is None or time() - tic < timeout * 60.:
if self.is_stopped():
return True
sleep(pool_period)
@@ -137,6 +155,7 @@ class TrainsJob(object):
return self.is_stopped()
def get_console_output(self, number_of_reports=1):
# type: (int) -> Sequence[str]
"""
Return a list of console outputs reported by the Task.
Returned console outputs are retrieved from the most updated console outputs.
@@ -147,6 +166,7 @@ class TrainsJob(object):
return self.task.get_reported_console_output(number_of_reports=number_of_reports)
def worker(self):
# type: () -> str
"""
Return the current worker id executing this Job. If job is pending, returns None
@@ -165,6 +185,7 @@ class TrainsJob(object):
return self._worker
def is_running(self):
# type: () -> bool
"""
Return True if job is currently running (pending is considered False)
@@ -173,6 +194,7 @@ class TrainsJob(object):
return self.task.status == Task.TaskStatusEnum.in_progress
def is_stopped(self):
# type: () -> bool
"""
Return True if the job has executed and is no longer running
@@ -183,6 +205,7 @@ class TrainsJob(object):
Task.TaskStatusEnum.failed, Task.TaskStatusEnum.published)
def is_pending(self):
# type: () -> bool
"""
Return True if job is waiting for execution
@@ -191,8 +214,20 @@ class TrainsJob(object):
return self.task.status in (Task.TaskStatusEnum.queued, Task.TaskStatusEnum.created)
class JobStub(object):
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, **_):
# noinspection PyMethodMayBeStatic, PyUnusedLocal
class _JobStub(object):
"""
This is a Job Stub, use only for debugging
"""
def __init__(
self,
base_task_id, # type: str
parameter_override=None, # type: Optional[Mapping[str, str]]
task_overrides=None, # type: Optional[Mapping[str, str]]
tags=None, # type: Optional[Sequence[str]]
**kwargs # type: Any
):
# type: (...) -> _JobStub
self.task = None
self.base_task_id = base_task_id
self.parameter_override = parameter_override
@@ -202,14 +237,17 @@ class JobStub(object):
self.task_started = None
def launch(self, queue_name=None):
# type: (str) -> ()
self.iteration = 0
self.task_started = time()
print('launching', self.parameter_override, 'in', queue_name)
def abort(self):
# type: () -> ()
self.task_started = -1
def elapsed(self):
# type: () -> float
"""
Return the time in seconds since job started. Return -1 if job is still pending
@@ -220,6 +258,7 @@ class JobStub(object):
return time() - self.task_started
def iterations(self):
# type: () -> int
"""
Return the last iteration value of the current job. -1 if job has not started yet
:return int: Task last iteration
@@ -229,6 +268,7 @@ class JobStub(object):
return self.iteration
def get_metric(self, title, series):
# type: (str, str) -> (float, float, float)
"""
Retrieve a specific scalar metric from the running Task.
@@ -239,15 +279,19 @@ class JobStub(object):
return 0, 1.0, 0.123
def task_id(self):
# type: () -> str
return 'stub'
def worker(self):
# type: () -> ()
return None
def status(self):
# type: () -> str
return 'in_progress'
def wait(self, timeout=None, pool_period=30.):
# type: (Optional[float], float) -> bool
"""
Wait for the task to be processed (i.e. aborted/completed/failed)
@@ -258,6 +302,7 @@ class JobStub(object):
return True
def get_console_output(self, number_of_reports=1):
# type: (int) -> Sequence[str]
"""
Return a list of console outputs reported by the Task.
Returned console outputs are retrieved from the most updated console outputs.
@@ -269,11 +314,13 @@ class JobStub(object):
return []
def is_running(self):
# type: () -> bool
return self.task_started is not None and self.task_started > 0
def is_stopped(self):
# type: () -> bool
return self.task_started is not None and self.task_started < 0
def is_pending(self):
# type: () -> bool
return self.task_started is None

View File

@@ -6,10 +6,11 @@ from itertools import product
from logging import getLogger
from threading import Thread, Event
from time import time
from typing import Union, Any, Sequence, Optional, Mapping, Callable
from .job import TrainsJob
from ..task import Task
from .parameters import Parameter
from ..task import Task
logger = getLogger('trains.automation.optimization')
@@ -32,6 +33,7 @@ class Objective(object):
"""
def __init__(self, title, series, order='max', extremum=False):
# type: (str, str, Union['max', 'min'], bool) -> Objective
"""
Construct objective object that will return the scalar value for a specific task ID
@@ -45,12 +47,12 @@ class Objective(object):
self.series = series
assert order in ('min', 'max',)
# normalize value so we always look for the highest objective value
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') or \
(not isinstance(order, str) and order < 0) else +1
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else +1
self._metric = None
self.extremum = extremum
def get_objective(self, task_id):
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
"""
Return a specific task scalar value based on the objective settings (title/series)
@@ -58,7 +60,7 @@ class Objective(object):
:return float: scalar value
"""
# create self._metric
self.get_last_metrics_encode_field()
self._get_last_metrics_encode_field()
if isinstance(task_id, Task):
task_id = task_id.id
@@ -85,6 +87,7 @@ class Objective(object):
return None
def get_current_raw_objective(self, task):
# type: (Union[TrainsJob, Task]) -> (int, float)
"""
Return the current raw value (without sign normalization) of the objective
@@ -103,12 +106,14 @@ class Objective(object):
# todo: replace with more efficient code
scalars = task.get_reported_scalars()
# noinspection PyBroadException
try:
return scalars[self.title][self.series]['x'][-1], scalars[self.title][self.series]['y'][-1]
except Exception:
return None
def get_objective_sign(self):
# type: () -> float
"""
Return the sign of the objective (i.e. +1 if maximizing, and -1 if minimizing)
@@ -117,6 +122,7 @@ class Objective(object):
return self.sign
def get_objective_metric(self):
# type: () -> (str, str)
"""
Return the metric title, series pair of the objective
@@ -125,6 +131,7 @@ class Objective(object):
return self.title, self.series
def get_normalized_objective(self, task_id):
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
"""
Return a normalized task scalar value based on the objective settings (title/series)
I.e. objective is always to maximize the returned value
@@ -138,7 +145,13 @@ class Objective(object):
# normalize value so we always look for the highest objective value
return self.sign * objective
def get_last_metrics_encode_field(self):
def _get_last_metrics_encode_field(self):
# type: () -> str
"""
Return encoded representation of title/series metric
:return str: string representing the objective title/series
"""
if not self._metric:
title = hashlib.md5(str(self.title).encode('utf-8')).hexdigest()
series = hashlib.md5(str(self.series).encode('utf-8')).hexdigest()
@@ -153,11 +166,21 @@ class SearchStrategy(object):
Base Search strategy class, inherit to implement your custom strategy
"""
_tag = 'optimization'
_job_class = TrainsJob
_job_class = TrainsJob # type: TrainsJob
def __init__(self, base_task_id, hyper_parameters, objective_metric,
execution_queue, num_concurrent_workers, pool_period_min=2.,
max_job_execution_minutes=None, total_max_jobs=None, **_):
def __init__(
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric, # type: Objective
execution_queue, # type: str
num_concurrent_workers, # type: int
pool_period_min=2., # type: float
max_job_execution_minutes=None, # type: Optional[float]
total_max_jobs=None, # type: Optional[int]
**_ # type: Any
):
# type: (...) -> SearchStrategy
"""
Initialize a search strategy optimizer
@@ -189,6 +212,7 @@ class SearchStrategy(object):
self._validate_base_task()
def start(self):
# type: () -> ()
"""
Start the Optimizer controller function loop()
If the calling process is stopped, the controller will stop as well.
@@ -205,6 +229,7 @@ class SearchStrategy(object):
counter += 1
def stop(self):
# type: () -> ()
"""
Stop the current running optimization loop,
Called from a different thread than the start()
@@ -212,6 +237,7 @@ class SearchStrategy(object):
self._stop_event.set()
def process_step(self):
# type: () -> bool
"""
Abstract helper function, not a must to implement, default use in start default implementation
Main optimization loop, called from the daemon thread created by start()
@@ -248,6 +274,7 @@ class SearchStrategy(object):
return bool(self._current_jobs)
def create_job(self):
# type: () -> Optional[TrainsJob]
"""
Abstract helper function, not a must to implement, default use in process_step default implementation
Create a new job if needed. return the newly created job.
@@ -258,17 +285,19 @@ class SearchStrategy(object):
return None
def monitor_job(self, job):
# type: (TrainsJob) -> bool
"""
Abstract helper function, not a must to implement, default use in process_step default implementation
Check if the job needs to be aborted or already completed
if return False, the job was aborted / completed, and should be taken off the current job list
:param TrainsJob job: a TrainsJob object to monitor
:return: boolean, If False, job is no longer relevant
:return bool: If False, job is no longer relevant
"""
return not job.is_stopped()
def get_running_jobs(self):
# type: () -> Sequence[TrainsJob]
"""
Return the current running TrainsJobs
@@ -277,28 +306,32 @@ class SearchStrategy(object):
return self._current_jobs
def get_created_jobs_ids(self):
# type: () -> Mapping[str, dict]
"""
Return a dict of task ids created by this optimizer until now, including completed and running jobs.
The values of the returned dict are the parameters used in the specific job
:return dict(str): dict of task ids (str) as keys, and their parameters dict as value
:return dict: dict of task ids (str) as keys, and their parameters dict as value
"""
return self._created_jobs_ids
def get_top_experiments(self, top_k):
# type: (int) -> Sequence[Task]
"""
Return a list of Tasks of the top performing experiments, based on the controller Objective object
:param int top_k: Number of Tasks (experiments) to return
:return list: List of Task objects, ordered by performance, where index 0 is the best performing Task.
"""
# metric_filter =
top_tasks = self._get_child_tasks(parent_task_id=self._job_parent_id or self._base_task_id,
order_by=self._objective_metric.get_last_metrics_encode_field(),
additional_filters={'page_size': int(top_k), 'page': 0})
# noinspection PyProtectedMember
top_tasks = self._get_child_tasks(
parent_task_id=self._job_parent_id or self._base_task_id,
order_by=self._objective_metric._get_last_metrics_encode_field(),
additional_filters={'page_size': int(top_k), 'page': 0})
return top_tasks
def get_objective_metric(self):
# type: () -> (str, str)
"""
Return the metric title, series pair of the objective
@@ -306,10 +339,19 @@ class SearchStrategy(object):
"""
return self._objective_metric.get_objective_metric()
def helper_create_job(self, base_task_id, parameter_override=None,
task_overrides=None, tags=None, parent=None, **kwargs):
def helper_create_job(
self,
base_task_id, # type: str
parameter_override=None, # type: Optional[Mapping[str, str]]
task_overrides=None, # type: Optional[Mapping[str, str]]
tags=None, # type: Optional[Sequence[str]]
parent=None, # type: Optional[str]
**kwargs # type: Any
):
# type: (...) -> TrainsJob
"""
Create a Job using the specified arguments, TrainsJob for details
:return TrainsJob: Returns a newly created Job instance
"""
if parameter_override:
@@ -334,6 +376,7 @@ class SearchStrategy(object):
return new_job
def set_job_class(self, job_class):
# type: (TrainsJob) -> ()
"""
Set the class to use for the helper_create_job function
@@ -342,6 +385,7 @@ class SearchStrategy(object):
self._job_class = job_class
def set_job_default_parent(self, job_parent_task_id):
# type: (str) -> ()
"""
Set the default parent for all Jobs created by the helper_create_job method
:param str job_parent_task_id: Parent task id
@@ -349,6 +393,7 @@ class SearchStrategy(object):
self._job_parent_id = job_parent_task_id
def set_job_naming_scheme(self, naming_function):
# type: (Optional[Callable[[str, dict], str]]) -> ()
"""
Set the function used to name a newly created job
@@ -357,6 +402,7 @@ class SearchStrategy(object):
self._naming_function = naming_function
def _validate_base_task(self):
# type: () -> ()
"""
Check the base task exists and contains the requested objective metric and hyper parameters
"""
@@ -378,6 +424,7 @@ class SearchStrategy(object):
self._objective_metric.get_objective_metric(), self._base_task_id))
def _get_task_project(self, parent_task_id):
# type: (str) -> (Optional[str])
if not parent_task_id:
return
if parent_task_id not in self._job_project:
@@ -387,7 +434,14 @@ class SearchStrategy(object):
return self._job_project.get(parent_task_id)
@classmethod
def _get_child_tasks(cls, parent_task_id, status=None, order_by=None, additional_filters=None):
def _get_child_tasks(
cls,
parent_task_id, # type: str
status=None, # type: Optional[Task.TaskStatusEnum]
order_by=None, # type: Optional[str]
additional_filters=None # type: Optional[dict]
):
# type: (...) -> (Sequence[Task])
"""
Helper function, return a list of tasks tagged automl with specific status ordered by sort_field
@@ -401,7 +455,7 @@ class SearchStrategy(object):
"execution.parameters.name"
"updated"
:param dict additional_filters: Additional task filters
:return List(Task): List of Task objects
:return list(Task): List of Task objects
"""
task_filter = {'parent': parent_task_id,
# 'tags': [cls._tag],
@@ -432,8 +486,19 @@ class GridSearch(SearchStrategy):
Full grid sampling of every hyper-parameter combination
"""
def __init__(self, base_task_id, hyper_parameters, objective_metric,
execution_queue, num_concurrent_workers, pool_period_min=2.0, max_job_execution_minutes=None, **_):
def __init__(
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric, # type: Objective
execution_queue, # type: str
num_concurrent_workers, # type: int
pool_period_min=2., # type: float
max_job_execution_minutes=None, # type: Optional[float]
total_max_jobs=None, # type: Optional[int]
**_ # type: Any
):
# type: (...) -> GridSearch
"""
Initialize a grid search optimizer
@@ -444,14 +509,17 @@ class GridSearch(SearchStrategy):
:param int num_concurrent_workers: Limit number of concurrent running machines
:param float pool_period_min: time in minutes between two consecutive pools
:param float max_job_execution_minutes: maximum time per single job in minutes, if exceeded job is aborted
:param int total_max_jobs: total maximum job for the optimization process. Default None, unlimited
"""
super(GridSearch, self).__init__(
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes)
pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes,
total_max_jobs=total_max_jobs, **_)
self._param_iterator = None
def create_job(self):
# type: () -> Optional[TrainsJob]
"""
Create a new job if needed. return the newly created job.
If no job needs to be created, return None
@@ -466,6 +534,7 @@ class GridSearch(SearchStrategy):
return self.helper_create_job(base_task_id=self._base_task_id, parameter_override=parameters)
def monitor_job(self, job):
# type: (TrainsJob) -> bool
"""
Check if the job needs to be aborted or already completed
if return False, the job was aborted / completed, and should be taken off the current job list
@@ -480,6 +549,7 @@ class GridSearch(SearchStrategy):
return not job.is_stopped()
def _next_configuration(self):
# type: () -> Mapping[str, str]
def param_iterator_fn():
hyper_params_values = [p.to_list() for p in self._hyper_parameters]
for state in product(*hyper_params_values):
@@ -499,9 +569,19 @@ class RandomSearch(SearchStrategy):
# Number of already chosen random samples before assuming we covered the entire hyper-parameter space
_hp_space_cover_samples = 42
def __init__(self, base_task_id, hyper_parameters, objective_metric,
execution_queue, num_concurrent_workers, pool_period_min=2.0,
max_job_execution_minutes=None, total_max_jobs=None, **_):
def __init__(
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric, # type: Objective
execution_queue, # type: str
num_concurrent_workers, # type: int
pool_period_min=2., # type: float
max_job_execution_minutes=None, # type: Optional[float]
total_max_jobs=None, # type: Optional[int]
**_ # type: Any
):
# type: (...) -> RandomSearch
"""
Initialize a random search optimizer
@@ -518,10 +598,11 @@ class RandomSearch(SearchStrategy):
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min,
max_job_execution_minutes=max_job_execution_minutes, total_max_jobs=total_max_jobs)
max_job_execution_minutes=max_job_execution_minutes, total_max_jobs=total_max_jobs, **_)
self._hyper_parameters_collection = set()
def create_job(self):
# type: () -> Optional[TrainsJob]
"""
Create a new job if needed. return the newly created job.
If no job needs to be created, return None
@@ -551,6 +632,7 @@ class RandomSearch(SearchStrategy):
return self.helper_create_job(base_task_id=self._base_task_id, parameter_override=parameters)
def monitor_job(self, job):
# type: (TrainsJob) -> bool
"""
Check if the job needs to be aborted or already completed
if return False, the job was aborted / completed, and should be taken off the current job list
@@ -572,18 +654,22 @@ class HyperParameterOptimizer(object):
"""
_tag = 'optimization'
def __init__(self, base_task_id,
hyper_parameters,
objective_metric_title,
objective_metric_series,
objective_metric_sign='min',
optimizer_class=RandomSearch,
max_number_of_concurrent_tasks=10,
execution_queue='default',
optimization_time_limit=None,
auto_connect_task=True,
always_create_task=False,
**optimizer_kwargs):
def __init__(
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric_title, # type: str
objective_metric_series, # type: str
objective_metric_sign='min', # type: Union['min', 'max', 'min_global', 'max_global']
optimizer_class=RandomSearch, # type: SearchStrategy
max_number_of_concurrent_tasks=10, # type: int
execution_queue='default', # type: str
optimization_time_limit=None, # type: Optional[float]
auto_connect_task=True, # type: bool
always_create_task=False, # type: bool
**optimizer_kwargs # type: Any
):
# type: (...) -> HyperParameterOptimizer
"""
Create a new hyper-parameter controller. The newly created object will launch and monitor the new experiments.
@@ -702,6 +788,7 @@ class HyperParameterOptimizer(object):
self.set_time_limit(in_minutes=opts['optimization_time_limit'])
def get_num_active_experiments(self):
# type: () -> int
"""
Return the number of current active experiments
@@ -712,6 +799,7 @@ class HyperParameterOptimizer(object):
return len(self.optimizer.get_running_jobs())
def get_active_experiments(self):
# type: () -> Sequence[Task]
"""
Return a list of Tasks of the current active experiments
@@ -722,6 +810,7 @@ class HyperParameterOptimizer(object):
return [j.task for j in self.optimizer.get_running_jobs()]
def start(self, job_complete_callback=None):
# type: (Optional[Callable[[str, float, int, dict, str], None]]) -> bool
"""
Start the HyperParameterOptimizer controller.
If the calling process is stopped, the controller will stop as well.
@@ -755,6 +844,7 @@ class HyperParameterOptimizer(object):
return True
def stop(self, timeout=None):
# type: (Optional[float]) -> ()
"""
Stop the HyperParameterOptimizer controller and optimization thread,
@@ -770,7 +860,7 @@ class HyperParameterOptimizer(object):
# wait for optimizer thread
if timeout is not None:
_thread.join(timeout=timeout*60.)
_thread.join(timeout=timeout * 60.)
# stop all running tasks:
for j in self.optimizer.get_running_jobs():
@@ -782,6 +872,7 @@ class HyperParameterOptimizer(object):
self._thread_reporter.join()
def is_active(self):
# type: () -> bool
"""
Return True if the optimization procedure is still running
Note, if the daemon thread has not yet started, it will still return True
@@ -791,6 +882,7 @@ class HyperParameterOptimizer(object):
return self._stop_event is None or self._thread is not None
def is_running(self):
# type: () -> bool
"""
Return True if the optimization controller is running
@@ -799,6 +891,7 @@ class HyperParameterOptimizer(object):
return self._thread is not None
def wait(self, timeout=None):
# type: (Optional[float]) -> bool
"""
Wait for the optimizer to finish.
It will not stop the optimizer in any case. Call stop() to terminate the optimizer.
@@ -825,6 +918,7 @@ class HyperParameterOptimizer(object):
return True
def set_time_limit(self, in_minutes=None, specific_time=None):
# type: (Optional[float], Optional[datetime]) -> ()
"""
Set a time limit for the HyperParameterOptimizer controller,
i.e. if we reached the time limit, stop the optimization process
@@ -835,9 +929,10 @@ class HyperParameterOptimizer(object):
if specific_time:
self.optimization_timeout = specific_time.timestamp()
else:
self.optimization_timeout = (in_minutes*60.) + time() if in_minutes else None
self.optimization_timeout = (in_minutes * 60.) + time() if in_minutes else None
def get_time_limit(self):
# type: () -> datetime
"""
Return the controller optimization time limit.
@@ -846,6 +941,7 @@ class HyperParameterOptimizer(object):
return datetime.fromtimestamp(self.optimization_timeout)
def elapsed(self):
# type: () -> float
"""
Return minutes elapsed from controller stating time stamp
@@ -853,9 +949,10 @@ class HyperParameterOptimizer(object):
"""
if self.optimization_start_time is None:
return -1.0
return (time() - self.optimization_start_time)/60.
return (time() - self.optimization_start_time) / 60.
def reached_time_limit(self):
# type: () -> bool
"""
Return True if we passed the time limit. Function returns immediately, it does not wait for the optimizer.
@@ -869,6 +966,7 @@ class HyperParameterOptimizer(object):
return time() > self.optimization_timeout
def get_top_experiments(self, top_k):
# type: (int) -> Sequence[Task]
"""
Return a list of Tasks of the top performing experiments, based on the controller Objective object
@@ -880,9 +978,16 @@ class HyperParameterOptimizer(object):
return self.optimizer.get_top_experiments(top_k=top_k)
def get_optimizer(self):
# type: () -> SearchStrategy
"""
Return the currently used optimizer object
:return SearchStrategy: Used SearchStrategy object
"""
return self.optimizer
def set_default_job_class(self, job_class):
# type: (TrainsJob) -> ()
"""
Set the Job class to use when the optimizer spawns new Jobs
@@ -891,6 +996,7 @@ class HyperParameterOptimizer(object):
self.optimizer.set_job_class(job_class)
def set_report_period(self, report_period_minutes):
# type: (float) -> ()
"""
Set reporting period in minutes, for the accumulated objective report
This report is sent on the Optimizer Task, and collects objective metric from all running jobs.
@@ -900,6 +1006,7 @@ class HyperParameterOptimizer(object):
self._report_period_min = float(report_period_minutes)
def _connect_args(self, optimizer_class=None, hyper_param_configuration=None, **kwargs):
# type: (SearchStrategy, dict, Any) -> (SearchStrategy, list, dict)
if not self._task:
logger.warning('Auto Connect turned on but no Task was found, '
'hyper-parameter optimization argument logging disabled')
@@ -937,6 +1044,7 @@ class HyperParameterOptimizer(object):
return optimizer_class, configuration_dict['parameter_optimization_space'], arguments['opt']
def _daemon(self):
# type: () -> ()
"""
implement the main pooling thread, calling loop every self.pool_period_minutes minutes
"""
@@ -944,6 +1052,7 @@ class HyperParameterOptimizer(object):
self._thread = None
def _report_daemon(self):
# type: () -> ()
worker_to_series = {}
title, series = self.objective_metric.get_objective_metric()
title = '{}/{}'.format(title, series)
@@ -956,8 +1065,8 @@ class HyperParameterOptimizer(object):
timeout = self.optimization_timeout - time() if self.optimization_timeout else 0.
if timeout >= 0:
timeout = min(self._report_period_min*60., timeout if timeout else self._report_period_min*60.)
print('Progress report #{} completed, sleeping for {} minutes'.format(counter, timeout/60.))
timeout = min(self._report_period_min * 60., timeout if timeout else self._report_period_min * 60.)
print('Progress report #{} completed, sleeping for {} minutes'.format(counter, timeout / 60.))
if self._stop_event.wait(timeout=timeout):
# wait for one last report
timeout = -1

View File

@@ -1,8 +1,7 @@
import sys
from itertools import product
from random import Random as BaseRandom
import numpy as np
from typing import Mapping, Any, Sequence, Optional, Union
class RandomSeed(object):
@@ -14,6 +13,7 @@ class RandomSeed(object):
@staticmethod
def set_random_seed(seed=1337):
# type: (int) -> ()
"""
Set global seed for all hyper parameter strategy random number sampling
@@ -24,6 +24,7 @@ class RandomSeed(object):
@staticmethod
def get_random_seed():
# type: () -> int
"""
Get the global seed for all hyper parameter strategy random number sampling
@@ -38,9 +39,16 @@ class Parameter(RandomSeed):
"""
def __init__(self, name):
# type: (Optional[str]) -> Parameter
"""
Create a new Parameter for hyper-parameter optimization
:param str name: give a name to the parameter, this is the parameter name that will be passed to a Task
"""
self.name = name
def get_value(self):
# type: () -> Mapping[str, Any]
"""
Return a dict with the Parameter name and a sampled value for the Parameter
@@ -49,16 +57,20 @@ class Parameter(RandomSeed):
pass
def to_list(self):
# type: () -> Sequence[Mapping[str, Any]]
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: values}
:return list: list of dicts {name: value}
"""
pass
def to_dict(self):
# type: () -> Mapping[str, Union[str, Parameter]]
"""
Return a dict representation of the parameter object
:return: dict
Return a dict representation of the parameter object. Used for serialization of the Parameter object.
:return dict: dict representation of the object (serialization)
"""
serialize = {'__class__': str(self.__class__).split('.')[-1][:-2]}
serialize.update(dict(((k, v.to_dict() if hasattr(v, 'to_dict') else v) for k, v in self.__dict__.items())))
@@ -66,9 +78,11 @@ class Parameter(RandomSeed):
@classmethod
def from_dict(cls, a_dict):
# type: (Mapping[str, str]) -> Parameter
"""
Construct parameter object from a dict representation
:return: Parameter
Construct parameter object from a dict representation (deserialize from dict)
:return Parameter: Parameter object
"""
a_dict = a_dict.copy()
a_cls = a_dict.pop('__class__', None)
@@ -89,7 +103,15 @@ class UniformParameterRange(Parameter):
Uniform randomly sampled Hyper-Parameter object
"""
def __init__(self, name, min_value, max_value, step_size=None, include_max_value=True):
def __init__(
self,
name, # type: str
min_value, # type: float
max_value, # type: float
step_size=None, # type: Optional[float]
include_max_value=True # type: bool
):
# type: (...) -> UniformParameterRange
"""
Create a parameter to be sampled by the SearchStrategy
@@ -106,6 +128,7 @@ class UniformParameterRange(Parameter):
self.include_max = include_max_value
def get_value(self):
# type: () -> Mapping[str, Any]
"""
Return uniformly sampled value based on object sampling definitions
@@ -117,14 +140,15 @@ class UniformParameterRange(Parameter):
return {self.name: self.min_value + (self._random.randrange(start=0, stop=round(steps)) * self.step_size)}
def to_list(self):
# type: () -> Sequence[Mapping[str, float]]
"""
Return a list of all the valid values of the Parameter
if self.step_size is not defined, return 100 points between minmax values
:return list: list of dicts {name: float}
"""
values = np.arange(start=self.min_value, stop=self.max_value,
step=self.step_size or (self.max_value - self.min_value) / 100.,
dtype=type(self.min_value)).tolist()
step_size = self.step_size or (self.max_value - self.min_value) / 100.
steps = (self.max_value - self.min_value) / self.step_size
values = [v*step_size for v in range(0, int(steps))]
if self.include_max and (not values or values[-1] < self.max_value):
values.append(self.max_value)
return [{self.name: v} for v in values]
@@ -152,6 +176,7 @@ class UniformIntegerParameterRange(Parameter):
self.include_max = include_max_value
def get_value(self):
# type: () -> Mapping[str, Any]
"""
Return uniformly sampled value based on object sampling definitions
@@ -162,10 +187,12 @@ class UniformIntegerParameterRange(Parameter):
stop=self.max_value + (0 if not self.include_max else self.step_size))}
def to_list(self):
# type: () -> Sequence[Mapping[str, int]]
"""
Return a list of all the valid values of the Parameter
if self.step_size is not defined, return 100 points between minmax values
:return list: list of dicts {name: float}
:return list: list of dicts {name: int}
"""
values = list(range(self.min_value, self.max_value, self.step_size))
if self.include_max and (not values or values[-1] < self.max_value):
@@ -189,6 +216,7 @@ class DiscreteParameterRange(Parameter):
self.values = values
def get_value(self):
# type: () -> Mapping[str, Any]
"""
Return uniformly sampled value from the valid list of values
@@ -197,8 +225,10 @@ class DiscreteParameterRange(Parameter):
return {self.name: self._random.choice(self.values)}
def to_list(self):
# type: () -> Sequence[Mapping[str, Any]]
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: value}
"""
return [{self.name: v} for v in self.values]
@@ -210,6 +240,7 @@ class ParameterSet(Parameter):
"""
def __init__(self, parameter_combinations=()):
# type: (Sequence[Mapping[str, Union[float, int, str, Parameter]]]) -> ParameterSet
"""
Uniformly sample values form a list of discrete options (combinations) of parameters
@@ -225,6 +256,7 @@ class ParameterSet(Parameter):
self.values = parameter_combinations
def get_value(self):
# type: () -> Mapping[str, Any]
"""
Return uniformly sampled value from the valid list of values
@@ -233,8 +265,10 @@ class ParameterSet(Parameter):
return self._get_value(self._random.choice(self.values))
def to_list(self):
# type: () -> Sequence[Mapping[str, Any]]
"""
Return a list of all the valid values of the Parameter
:return list: list of dicts {name: value}
"""
combinations = []
@@ -253,6 +287,7 @@ class ParameterSet(Parameter):
@staticmethod
def _get_value(combination):
# type: (dict) -> dict
value_dict = {}
for k, v in combination.items():
if isinstance(v, Parameter):

View File

@@ -23,12 +23,12 @@ class ApiServiceProxy(object):
if not ApiServiceProxy._max_available_version:
from ..backend_api import services
ApiServiceProxy._max_available_version = max([
Version(name[1:].replace("_", "."))
for name in [
module_name
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
]])
Version(name[1:].replace("_", "."))
for name in [
module_name
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
]])
version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
self.__dict__["__wrapped_version__"] = Session.api_version

View File

@@ -59,7 +59,7 @@ class DataModel(object):
props = {}
for c in cls.__mro__:
props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
if isinstance(v, property)})
if isinstance(v, property)})
cls._data_props_list = props
return props.copy()
@@ -150,6 +150,7 @@ class NonStrictDataModelMixin(object):
:summary: supplies an __init__ method that warns about unused keywords
"""
def __init__(self, **kwargs):
# unexpected = [key for key in kwargs if not key.startswith('_')]
# if unexpected:

View File

@@ -312,8 +312,7 @@ class Config(object):
return ConfigFactory.parse_file(file_path)
except ParseSyntaxException as ex:
msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
file_path, ex
)
file_path, ex)
six.reraise(
ConfigurationError,
ConfigurationError(msg, file_path=file_path),

View File

@@ -273,7 +273,7 @@ class UploadEvent(MetricsEventAdapter):
image_data = np.atleast_3d(image_data)
if image_data.dtype != np.uint8:
if np.issubdtype(image_data.dtype, np.floating) and image_data.max() <= 1.0:
image_data = (image_data*255).astype(np.uint8)
image_data = (image_data * 255).astype(np.uint8)
else:
image_data = image_data.astype(np.uint8)
shape = image_data.shape
@@ -318,7 +318,7 @@ class UploadEvent(MetricsEventAdapter):
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
# make sure we preserve local path root
if e_storage_uri.startswith('/'):
url = '/'+url
url = '/' + url
if quote_uri:
url = quote_url(url)

View File

@@ -155,7 +155,7 @@ class Metrics(InterfaceBase):
# upload files
def upload(e):
upload_uri = e.upload_uri or storage_uri
try:
storage = self._get_storage(upload_uri)
url = storage.upload_from_stream(e.stream, e.url, retries=self._file_upload_retries)
@@ -234,4 +234,3 @@ class Metrics(InterfaceBase):
pool.join()
except:
pass

View File

@@ -210,7 +210,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
))
self.reload()
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None):
if tags:
extra = {'system_tags': tags or self.data.system_tags} \
@@ -318,7 +318,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
callback = partial(
self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
override_model_id=override_model_id, cb=cb)
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable, cb=callback)
uri = self._upload_model(model_file, target_filename=target_filename,
async_enable=async_enable, cb=callback)
return uri
else:
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)

View File

@@ -16,6 +16,7 @@ class _Arguments(object):
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, arguments, *args, **kwargs):
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
self._arguments = arguments
@@ -27,6 +28,7 @@ class _Arguments(object):
class _ProxyDictReadOnly(dict):
""" Dictionary wrapper that prevents modifications to the dictionary """
def __init__(self, arguments, *args, **kwargs):
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
self._arguments = arguments

View File

@@ -16,7 +16,7 @@ buffer_capacity = config.get('log.task_log_buffer_capacity', 100)
class TaskHandler(BufferingHandler):
__flush_max_history_seconds = 30.
__wait_for_flush_timeout = 10.
__max_event_size = 1024*1024
__max_event_size = 1024 * 1024
__once = False
@property

View File

@@ -24,7 +24,7 @@ class ScriptInfoError(Exception):
class ScriptRequirements(object):
_max_requirements_size = 512*1024
_max_requirements_size = 512 * 1024
def __init__(self, root_folder):
self._root_folder = root_folder
@@ -365,7 +365,7 @@ class ScriptInfo(object):
@classmethod
def _get_jupyter_notebook_filename(cls):
if not (sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or
if not (sys.argv[0].endswith(os.path.sep + 'ipykernel_launcher.py') or
sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
return None

View File

@@ -119,7 +119,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._curr_label_stats = {}
self._raise_on_validation_errors = raise_on_validation_errors
self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
)
self._app_server = None
self._files_server = None
@@ -683,7 +683,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
@@ -1300,5 +1300,5 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def __is_subprocess(cls):
# notice this class function is called from Task.ExitHooks, do not rename/move it.
is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
return is_subprocess

View File

@@ -1 +0,0 @@

View File

@@ -199,6 +199,7 @@ class Artifacts(object):
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, artifacts_manager, *args, **kwargs):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager
@@ -303,8 +304,8 @@ class Artifacts(object):
artifact_type_data.content_type = 'application/numpy'
artifact_type_data.preview = str(artifact_object.__repr__())
override_filename_ext_in_uri = '.npz'
override_filename_in_uri = name+override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
os.close(fd)
np.savez_compressed(local_filename, **{name: artifact_object})
delete_after_upload = True
@@ -314,7 +315,7 @@ class Artifacts(object):
artifact_type_data.preview = str(artifact_object.__repr__())
override_filename_ext_in_uri = self._save_format
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
os.close(fd)
artifact_object.to_csv(local_filename, compression=self._compression)
delete_after_upload = True
@@ -325,7 +326,7 @@ class Artifacts(object):
artifact_type_data.preview = desc[1:desc.find(' at ')]
override_filename_ext_in_uri = '.png'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
os.close(fd)
artifact_object.save(local_filename)
delete_after_upload = True
@@ -335,7 +336,7 @@ class Artifacts(object):
preview = json.dumps(artifact_object, sort_keys=True, indent=4)
override_filename_ext_in_uri = '.json'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
os.write(fd, bytes(preview.encode()))
os.close(fd)
artifact_type_data.preview = preview
@@ -374,7 +375,7 @@ class Artifacts(object):
override_filename_ext_in_uri = '.zip'
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
fd, zip_file = mkstemp(
prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri
prefix=quote(folder.parts[-1], safe="") + '.', suffix=override_filename_ext_in_uri
)
try:
artifact_type_data.content_type = 'application/zip'
@@ -571,7 +572,8 @@ class Artifacts(object):
artifact_type_data.data_hash = current_sha2
artifact_type_data.content_type = "text/csv"
artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})
artifact_type_data.preview = str(pd_artifact.__repr__(
)) + '\n\n' + self._get_statistics({name: pd_artifact})
artifact.type_data = artifact_type_data
artifact.uri = uri
@@ -657,11 +659,11 @@ class Artifacts(object):
# build intersection summary
for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
name=name, shape=shape, unique=len(unique_hash), percentage=100 * len(unique_hash) / float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
intersection = len(unique_hash & unique_hash2)
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
name2=name2, intersection=intersection, percentage=100 * intersection / float(len(unique_hash2)))
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
finally:

View File

@@ -230,8 +230,8 @@ class EventTrainsWriter(object):
:return: (str, str) variant and metric
"""
splitted_tag = tag.split(split_char)
if auto_reduce_num_split and num_split_parts > len(splitted_tag)-1:
num_split_parts = max(1, len(splitted_tag)-1)
if auto_reduce_num_split and num_split_parts > len(splitted_tag) - 1:
num_split_parts = max(1, len(splitted_tag) - 1)
series = join_char.join(splitted_tag[-num_split_parts:])
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
@@ -356,7 +356,7 @@ class EventTrainsWriter(object):
val = val[:, :, [0, 1, 2]]
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
% (width, height, color_channels))
% (width, height, color_channels))
val = None
return val
@@ -525,7 +525,7 @@ class EventTrainsWriter(object):
stream = BytesIO(audio_data)
if values:
file_extension = guess_extension(values['contentType']) or \
'.{}'.format(values['contentType'].split('/')[-1])
'.{}'.format(values['contentType'].split('/')[-1])
else:
# assume wav as default
file_extension = '.wav'
@@ -548,7 +548,7 @@ class EventTrainsWriter(object):
wraparound_counter = EventTrainsWriter._title_series_wraparound_counter[key]
# we decide on wrap around if the current step is less than 10% of the previous step
# notice since counter is int and we want to avoid rounding error, we have double check in the if
if step < wraparound_counter['last_step'] and step < 0.9*wraparound_counter['last_step']:
if step < wraparound_counter['last_step'] and step < 0.9 * wraparound_counter['last_step']:
# adjust step base line
wraparound_counter['adjust_counter'] += wraparound_counter['last_step'] + (1 if step <= 0 else step)
@@ -582,7 +582,8 @@ class EventTrainsWriter(object):
msg_dict.pop('wallTime', None)
keys_list = [key for key in msg_dict.keys() if len(key) > 0]
keys_list = ', '.join(keys_list)
LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
LoggerRoot.get_base_logger(TensorflowBinding).debug(
'event summary not found, message type unsupported: %s' % keys_list)
return
value_dicts = summary.get('value')
walltime = walltime or msg_dict.get('step')
@@ -594,7 +595,8 @@ class EventTrainsWriter(object):
step = int(event.step)
else:
step = 0
LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
LoggerRoot.get_base_logger(TensorflowBinding).debug(
'Received event without step, assuming step = {}'.format(step))
else:
step = int(step)
self._max_step = max(self._max_step, step)
@@ -1036,7 +1038,7 @@ class PatchModelCheckPointCallback(object):
name=defaults_dict.get('name'),
comment=defaults_dict.get('comment'),
label_enumeration=defaults_dict.get('label_enumeration') or
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
framework=Framework.keras,
)
output_model.set_upload_destination(
@@ -1206,7 +1208,7 @@ class PatchTensorFlowEager(object):
for i in range(2, img_data_np.size):
img_data = {'width': -1, 'height': -1,
'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
image_tag = str(tag)+'/sample_{}'.format(i-2) if img_data_np.size > 3 else str(tag)
image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
event_writer._add_image(tag=image_tag,
step=int(step.numpy()) if not isinstance(step, int) else step,
img_data=img_data)
@@ -1291,7 +1293,7 @@ class PatchKerasModelIO(object):
PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
PatchKerasModelIO._from_config))
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)

View File

@@ -249,7 +249,7 @@ class PatchedMatplotlib:
# check if this is an imshow
if hasattr(stored_figure, '_trains_is_imshow'):
# flag will be cleared when calling clf() (object will be replaced)
stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow-1)
stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow - 1)
force_save_as_image = True
# get current figure
mpl_fig = stored_figure.canvas.figure # plt.gcf()
@@ -342,7 +342,7 @@ class PatchedMatplotlib:
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
facecolor=None)
buffer_.seek(0)
fd, image = mkstemp(suffix='.'+image_format)
fd, image = mkstemp(suffix='.' + image_format)
os.write(fd, buffer_.read())
os.close(fd)

View File

@@ -1 +0,0 @@

View File

@@ -127,7 +127,8 @@ def main():
# this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(
'demoapp.', 'demofiles.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('app.'):
# this is our application server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
@@ -138,7 +139,8 @@ def main():
parsed_host.netloc))
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(
'demoapi.', 'demofiles.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('api.'):
print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
parsed_host.netloc))

View File

@@ -42,8 +42,9 @@ def _log_stderr(name, fnc, args, kwargs, is_return):
# get a nicer thread id
h = int(__thread_id())
ts = time.time() - __trace_start
__stream_write('{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n'.format('-' if is_return else '',
ts, os.getpid(), h, threading.current_thread().name, t))
__stream_write('{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n'.format(
'-' if is_return else '', ts, os.getpid(),
h, threading.current_thread().name, t))
if __stream_flush:
__stream_flush()
except:
@@ -167,7 +168,7 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None):
# check if this is even in our module
if hasattr(fnc, '__module__') and fnc.__module__ != module.__module__:
pass # print('not ours {} {}'.format(module, fnc))
elif hasattr(fnc, '__qualname__') and fnc.__qualname__.startswith(module.__name__+'.'):
elif hasattr(fnc, '__qualname__') and fnc.__qualname__.startswith(module.__name__ + '.'):
if isinstance(module.__dict__[fn], classmethod):
setattr(module, fn, _traced_call_cls(prefix + fn, fnc))
elif isinstance(module.__dict__[fn], staticmethod):
@@ -176,7 +177,7 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None):
setattr(module, fn, _traced_call_method(prefix + fn, fnc))
else:
# probably not ours hopefully static function
if hasattr(fnc, '__qualname__') and not fnc.__qualname__.startswith(module.__name__+'.'):
if hasattr(fnc, '__qualname__') and not fnc.__qualname__.startswith(module.__name__ + '.'):
pass # print('not ours {} {}'.format(module, fnc))
else:
# we should not get here
@@ -232,7 +233,7 @@ def trace_trains(stream=None, level=1):
# store to stream
__stream_write(msg)
__stream_write('{:9}:{:5}:{:8}: {:14}\n'.format('seconds', 'pid', 'tid', 'self'))
__stream_write('{:9}:{:5}:{:8}:{:15}\n'.format('-'*9, '-'*5, '-'*8, '-'*15))
__stream_write('{:9}:{:5}:{:8}:{:15}\n'.format('-' * 9, '-' * 5, '-' * 8, '-' * 15))
__trace_start = time.time()
_patch_module('trains')
@@ -267,6 +268,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
:param specify_pids: optional list of pids to include
"""
from glob import glob
def hash_line(a_line):
return hash(':'.join(a_line.split(':')[1:]))
@@ -304,7 +306,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
by_time = {}
for p, tids in pids.items():
for t, lines in tids.items():
ts = float(lines[-1].split(':')[0].strip()) + 0.000001*len(by_time)
ts = float(lines[-1].split(':')[0].strip()) + 0.000001 * len(by_time)
if print_orphans:
for i, l in enumerate(lines):
if i > 0 and hash_line(l) in orphan_calls:
@@ -313,7 +315,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
out_stream = open(stream, 'w') if isinstance(stream, str) else stream
for k in sorted(by_time.keys()):
out_stream.write(by_time[k]+'\n')
out_stream.write(by_time[k] + '\n')
if isinstance(stream, str):
out_stream.close()
@@ -322,6 +324,7 @@ def end_of_program():
# stub
pass
if __name__ == '__main__':
# from trains import Task
# task = Task.init(project_name="examples", task_name="trace test")

View File

@@ -522,15 +522,15 @@ class Logger(object):
"""
# check if multiple series
multi_series = (
isinstance(scatter, list)
and (
isinstance(scatter[0], np.ndarray)
or (
scatter[0]
and isinstance(scatter[0], list)
and isinstance(scatter[0][0], list)
)
isinstance(scatter, list)
and (
isinstance(scatter[0], np.ndarray)
or (
scatter[0]
and isinstance(scatter[0], list)
and isinstance(scatter[0][0], list)
)
)
)
if not multi_series:

View File

@@ -109,9 +109,7 @@ class BaseModel(object):
"""
The Id (system UUID) of the model.
:return: The model id.
:rtype: str
:return str: The model id.
"""
return self._get_model_data().id
@@ -121,9 +119,7 @@ class BaseModel(object):
"""
The name of the model.
:return: The model name.
:rtype: str
:return str: The model name.
"""
return self._get_model_data().name
@@ -143,9 +139,7 @@ class BaseModel(object):
"""
The comment for the model. Also, use for a model description.
:return: The model comment / description.
:rtype: str
:return str: The model comment / description.
"""
return self._get_model_data().comment
@@ -165,9 +159,7 @@ class BaseModel(object):
"""
A list of tags describing the model.
:return: The list of tags.
:rtype: list(str)
:return list(str): The list of tags.
"""
return self._get_model_data().tags
@@ -189,9 +181,7 @@ class BaseModel(object):
"""
The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
:return: The configuration.
:rtype: str
:return str: The configuration.
"""
return _Model._unwrap_design(self._get_model_data().design)
@@ -202,9 +192,7 @@ class BaseModel(object):
The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
For example, prototxt, an ini file, or Python code to evaluate.
:return: The configuration.
:rtype: dict
:return str: The configuration.
"""
return self._text_to_config_dict(self.config_text)
@@ -215,9 +203,7 @@ class BaseModel(object):
The label enumeration of string (label) to integer (value) pairs.
:return: A dictionary containing labels enumeration, where the keys are labels and the values as integers.
:rtype: dict
:return dict: A dictionary containing labels enumeration, where the keys are labels and the values as integers.
"""
return self._get_model_data().labels
@@ -266,9 +252,7 @@ class BaseModel(object):
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return: The locally stored file.
:rtype: str
:return str: The locally stored file.
"""
# download model (synchronously) and return local file
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
@@ -288,8 +272,6 @@ class BaseModel(object):
raise ValueError, otherwise return None on failure and output log warning.
:return: The model weights, or a list of the locally stored filenames.
:rtype: package or path
"""
# check if model was packaged
if self._package_tag not in self._get_model_data().tags:
@@ -508,7 +490,7 @@ class InputModel(Model):
:type config_text: unconstrained text string
:param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
but not both.
:param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs. (Optional)
:param dict label_enumeration: Optional label enumeration dictionary of string (label) to integer (value) pairs.
For example:
@@ -538,8 +520,6 @@ class InputModel(Model):
:type framework: str or Framework object
:return: The imported model or existing model (see above).
:rtype: A model object.
"""
config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
weights_url = StorageHelper.conform_url(weights_url)
@@ -578,7 +558,7 @@ class InputModel(Model):
from .task import Task
task = Task.current_task()
if task:
comment = 'Imported by task id: {}'.format(task.id) + ('\n'+comment if comment else '')
comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
project_id = task.project
task_id = task.id
else:
@@ -776,9 +756,7 @@ class OutputModel(BaseModel):
"""
Get the published state of this model.
:return: ``True`` if the model is published, ``False`` otherwise.
:rtype: bool
:return bool: ``True`` if the model is published, ``False`` otherwise.
"""
if not self.id:
return False
@@ -790,9 +768,7 @@ class OutputModel(BaseModel):
"""
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
:return: The configuration.
:rtype: str
:return str: The configuration.
"""
return _Model._unwrap_design(self._get_model_data().design)
@@ -811,9 +787,7 @@ class OutputModel(BaseModel):
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
configuration. For example, from prototxt to ini file or python code to evaluate.
:return: The configuration.
:rtype: dict
:return dict: The configuration.
"""
return self._text_to_config_dict(self.config_text)
@@ -842,9 +816,7 @@ class OutputModel(BaseModel):
'person': 1
}
:return: The label enumeration.
:rtype: dict
:return dict: The label enumeration.
"""
return self._get_model_data().labels
@@ -889,7 +861,8 @@ class OutputModel(BaseModel):
Create a new model and immediately connect it to a task.
We do not allow for Model creation without a task, so we always keep track on how we created the models
In remote execution, Model parameters can be overridden by the Task (such as model configuration & label enumerator)
In remote execution, Model parameters can be overridden by the Task
(such as model configuration & label enumerator)
:param task: The Task object with which the OutputModel object is associated.
:type task: Task
@@ -979,12 +952,12 @@ class OutputModel(BaseModel):
if running_remotely() and task.is_main_task():
if self._floating_data:
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
self._floating_data.design
self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels
self._floating_data.labels
elif self._base_model:
self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
self._base_model.design)
self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None:
@@ -1005,8 +978,8 @@ class OutputModel(BaseModel):
def set_upload_destination(self, uri):
# type: (str) -> None
"""
Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include
S3, Google Cloud Storage), and file locations.
Set the URI of the storage destination for uploaded model weight files.
Supported storage destinations include S3, Google Cloud Storage), and file locations.
Using this method, files uploads are separate and then a link to each is stored in the model object.
@@ -1021,12 +994,10 @@ class OutputModel(BaseModel):
- ``s3://bucket/directory/``
- ``file:///tmp/debug/``
:return: The status of whether the storage destination schema is supported.
:return bool: The status of whether the storage destination schema is supported.
- ``True`` - The storage destination scheme is supported.
- ``False`` - The storage destination scheme is not supported.
:rtype: bool
"""
if not uri:
return
@@ -1063,8 +1034,8 @@ class OutputModel(BaseModel):
.. note::
Uploading the model is a background process. A call to this method returns immediately.
:param str weights_filename: The name of the locally stored weights file to upload. Specify ``weights_filename``
or ``register_uri``, but not both.
:param str weights_filename: The name of the locally stored weights file to upload.
Specify ``weights_filename`` or ``register_uri``, but not both.
:param str upload_uri: The URI of the storage destination for model weights upload. The default value
is the previously used URI. (Optional)
:param str target_filename: The newly created filename in the storage destination location. The default value
@@ -1083,9 +1054,7 @@ class OutputModel(BaseModel):
- ``True`` - Update model comment (Default)
- ``False`` - Do not update
:return: The uploaded URI.
:rtype: str
:return str: The uploaded URI.
"""
def delete_previous_weights_file(filename=weights_filename):
@@ -1120,7 +1089,8 @@ class OutputModel(BaseModel):
if not model:
raise ValueError('Failed creating internal output model')
# select the correct file extension based on the framework, or update the framework based on the file extension
# select the correct file extension based on the framework,
# or update the framework based on the file extension
framework, file_ext = Framework._get_file_ext(
framework=self._get_model_data().framework,
filename=target_filename or weights_filename or register_uri
@@ -1211,13 +1181,12 @@ class OutputModel(BaseModel):
:param int iteration: The iteration number.
:return: The uploaded URI for the weights package.
:rtype: str
:return str: The uploaded URI for the weights package.
"""
# create list of files
if (not weights_filenames and not weights_path) or (weights_filenames and weights_path):
raise ValueError('Model update weights package should get either directory path to pack or a list of files')
raise ValueError('Model update weights package should get either '
'directory path to pack or a list of files')
if not weights_filenames:
weights_filenames = list(map(six.text_type, Path(weights_path).rglob('*')))
@@ -1276,14 +1245,13 @@ class OutputModel(BaseModel):
:param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
but not both.
:return: The status of the update.
:return bool: The status of the update.
- ``True`` - Update successful.
- ``False`` - Update not successful.
:rtype: bool
"""
if not self._validate_update():
return
return False
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)

View File

@@ -575,27 +575,27 @@ class StorageHelper(object):
def list(self, prefix=None):
"""
List entries in the helper base path.
Return a list of names inside this helper base path. The base path is
determined at creation time and is specific for each storage medium.
For Google Storage and S3 it is the bucket of the path.
For local files it is the root directory.
This operation is not supported for http and https protocols.
:param prefix: If None, return the list as described above. If not, it
must be a string - the path of a sub directory under the base path.
the returned list will include only objects under that subdir.
:return: List of strings - the paths of all the objects in the storage base
path under prefix. Listed relative to the base path.
"""
if prefix:
if prefix.startswith(self._base_url):
prefix = prefix[len(self.base_url):].lstrip("/")
try:
res = self._driver.list_container_objects(self._container, ex_prefix=prefix)
except TypeError:
@@ -943,7 +943,7 @@ class StorageHelper(object):
cb(dest_path)
except Exception as e:
self._log.warning("Exception on upload callback: %s" % str(e))
return dest_path
def _get_object(self, path):
@@ -1012,8 +1012,8 @@ class _HttpDriver(_Driver):
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
url = object_name[:object_name.index('/')]
url_path = object_name[len(url)+1:]
full_url = container.name+url
url_path = object_name[len(url) + 1:]
full_url = container.name + url
# when sending data in post, there is no connection timeout, just an entire upload timeout
timeout = self.timeout[-1]
if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'):
@@ -1668,7 +1668,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
def callback_func(current, total):
if callback:
chunk = current-download_done.counter
chunk = current - download_done.counter
download_done.counter += chunk
callback(chunk)
if current >= total:

View File

@@ -618,16 +618,22 @@ class Task(_Task):
return self.get_models()
@classmethod
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
# type: (Optional[Task], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> Task
def clone(
cls,
source_task=None, # type: Optional[Union[Task, str]]
name=None, # type: Optional[str]
comment=None, # type: Optional[str]
parent=None, # type: Optional[str]
project=None, # type: Optional[str]
):
# type: (...) -> Task
"""
Create a duplicate (a clone) of a Task (experiment). The status of the cloned Task is ``Draft``
and modifiable.
Use this method to manage experiments and for autoML.
:param source_task: The Task to clone. Specify a Task object or a Task Id. (Optional)
:type source_task: Task/str
:param str source_task: The Task to clone. Specify a Task object or a Task Id. (Optional)
:param str name: The name of the new cloned Task. (Optional)
:param str comment: A comment / description for the new cloned Task. (Optional)
:param str parent: The Id of the parent Task of the new Task.

View File

@@ -37,7 +37,7 @@ def range_validator(min_value, max_value):
"""
def _range_validator(instance, attribute, value):
if ((min_value is not None) and (value < min_value)) or \
((max_value is not None) and (value > max_value)):
((max_value is not None) and (value > max_value)):
raise ValueError("{} must be in range [{}, {}]".format(attribute.name, min_value, max_value))
return _range_validator
@@ -159,5 +159,3 @@ class TaskParameters(object):
"""
return task.connect(self)

View File

@@ -18,7 +18,7 @@ class PatchArgumentParser:
def add_subparsers(self, **kwargs):
if 'dest' not in kwargs:
if kwargs.get('title'):
kwargs['dest'] = '/'+kwargs['title']
kwargs['dest'] = '/' + kwargs['title']
else:
PatchArgumentParser._add_subparsers_counter += 1
kwargs['dest'] = '/subparser%d' % PatchArgumentParser._add_subparsers_counter
@@ -80,14 +80,16 @@ class PatchArgumentParser:
# if we do we need to patch the args, because there is no default subparser
if PY2:
import itertools
def _get_sub_parsers_defaults(subparser, prev=[]):
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(subparser, _SubParsersAction) else \
[subparser._actions]
sub_parsers_defaults = [[subparser]] if hasattr(subparser, 'default') and subparser.default else []
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(
subparser, _SubParsersAction) else [subparser._actions]
sub_parsers_defaults = [[subparser]] if hasattr(
subparser, 'default') and subparser.default else []
for actions in actions_grp:
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev)
for a in actions if isinstance(a, _SubParsersAction) and
hasattr(a, 'default') and a.default]
for a in actions if isinstance(a, _SubParsersAction) and
hasattr(a, 'default') and a.default]
return list(itertools.chain.from_iterable(sub_parsers_defaults))
sub_parsers_defaults = _get_sub_parsers_defaults(self)

View File

@@ -39,8 +39,8 @@ def get_percentage(config, key, required=True, default=None):
def get_human_size_default(config, key, default=None):
raw_value = config.get(key, default)
if raw_value is None:
return default
return parse_human_size(raw_value)
return parse_human_size(raw_value)

View File

@@ -118,4 +118,3 @@ class DeferredExecution(object):
return func(instance, *args, **kwargs)
return wrapper
return decorator

View File

@@ -39,6 +39,7 @@ class BlobsDict(dict):
"""
Overloading getitem so that the 'data' copy is only done when the dictionary item is accessed.
"""
def __init__(self, *args, **kwargs):
super(BlobsDict, self).__init__(*args, **kwargs)
@@ -60,6 +61,7 @@ class BlobsDict(dict):
class NestedBlobsDict(BlobsDict):
"""A dictionary that applies an arbitrary key-altering function
before accessing the keys."""
def __init__(self, *args, **kwargs):
super(NestedBlobsDict, self).__init__(*args, **kwargs)
@@ -96,14 +98,14 @@ class NestedBlobsDict(BlobsDict):
for key in cur_keys:
if isinstance(cur_dict[key], dict):
if len(path) > 0:
deep_keys.extend(self._keys(cur_dict[key], path+ '.' + key))
deep_keys.extend(self._keys(cur_dict[key], path + '.' + key))
else:
deep_keys.extend(self._keys(cur_dict[key], key))
else:
if len(path) > 0:
deep_keys.append(path + '.' + key)
else:
deep_keys.append( key)
deep_keys.append(key)
return deep_keys

View File

@@ -37,7 +37,7 @@ import threading
import string
## C Type mappings ##
## Enums
# Enums
_nvmlEnableState_t = c_uint
NVML_FEATURE_DISABLED = 0
NVML_FEATURE_ENABLED = 1
@@ -347,7 +347,7 @@ def _nvmlGetFunctionPointer(name):
libLoadLock.release()
## Alternative object
# Alternative object
# Allows the object to be printed
# Allows mismatched types to be assigned
# - like None when the Structure variant requires c_uint
@@ -379,7 +379,7 @@ def nvmlFriendlyObjectToStruct(obj, model):
return model
## Unit structures
# Unit structures
class struct_c_nvmlUnit_t(Structure):
pass # opaque handle
@@ -461,7 +461,7 @@ class c_nvmlUnitFanSpeeds_t(_PrintableStructure):
]
## Device structures
# Device structures
class struct_c_nvmlDevice_t(Structure):
pass # opaque handle
@@ -591,7 +591,7 @@ class c_nvmlViolationTime_t(_PrintableStructure):
]
## Event structures
# Event structures
class struct_c_nvmlEventSet_t(Structure):
pass # opaque handle
@@ -605,29 +605,30 @@ nvmlEventTypeXidCriticalError = 0x0000000000000008
nvmlEventTypeClock = 0x0000000000000010
nvmlEventTypeNone = 0x0000000000000000
nvmlEventTypeAll = (
nvmlEventTypeNone |
nvmlEventTypeSingleBitEccError |
nvmlEventTypeDoubleBitEccError |
nvmlEventTypePState |
nvmlEventTypeClock |
nvmlEventTypeXidCriticalError
nvmlEventTypeNone |
nvmlEventTypeSingleBitEccError |
nvmlEventTypeDoubleBitEccError |
nvmlEventTypePState |
nvmlEventTypeClock |
nvmlEventTypeXidCriticalError
)
## Clock Throttle Reasons defines
# Clock Throttle Reasons defines
nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001
nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002
nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting
# deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting
nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting
nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004
nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008
nvmlClocksThrottleReasonUnknown = 0x8000000000000000
nvmlClocksThrottleReasonNone = 0x0000000000000000
nvmlClocksThrottleReasonAll = (
nvmlClocksThrottleReasonNone |
nvmlClocksThrottleReasonGpuIdle |
nvmlClocksThrottleReasonApplicationsClocksSetting |
nvmlClocksThrottleReasonSwPowerCap |
nvmlClocksThrottleReasonHwSlowdown |
nvmlClocksThrottleReasonUnknown
nvmlClocksThrottleReasonNone |
nvmlClocksThrottleReasonGpuIdle |
nvmlClocksThrottleReasonApplicationsClocksSetting |
nvmlClocksThrottleReasonSwPowerCap |
nvmlClocksThrottleReasonHwSlowdown |
nvmlClocksThrottleReasonUnknown
)
@@ -785,7 +786,7 @@ def nvmlSystemGetHicVersion():
return hics
## Unit get functions
# Unit get functions
def nvmlUnitGetCount():
c_count = c_uint()
fn = _nvmlGetFunctionPointer("nvmlUnitGetCount")
@@ -865,7 +866,7 @@ def nvmlUnitGetDevices(unit):
return c_devices
## Device get functions
# Device get functions
def nvmlDeviceGetCount():
c_count = c_uint()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2")
@@ -919,7 +920,7 @@ def nvmlDeviceGetName(handle):
def nvmlDeviceGetBoardId(handle):
c_id = c_uint();
c_id = c_uint()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId")
ret = fn(handle, byref(c_id))
_nvmlCheckReturn(ret)
@@ -927,7 +928,7 @@ def nvmlDeviceGetBoardId(handle):
def nvmlDeviceGetMultiGpuBoard(handle):
c_multiGpu = c_uint();
c_multiGpu = c_uint()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard")
ret = fn(handle, byref(c_multiGpu))
_nvmlCheckReturn(ret)
@@ -1480,7 +1481,7 @@ def nvmlDeviceGetAutoBoostedClocksEnabled(handle):
# Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks
## Set functions
# Set functions
def nvmlUnitSetLedState(unit, color):
fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState")
ret = fn(unit, _nvmlLedColor_t(color))
@@ -1800,7 +1801,7 @@ def nvmlDeviceGetSamples(device, sampling_type, timeStamp):
c_sample_value_type = _nvmlValueType_t()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples")
## First Call gets the size
# First Call gets the size
ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), None)
# Stop if this fails
@@ -1819,7 +1820,7 @@ def nvmlDeviceGetViolationStatus(device, perfPolicyType):
c_violTime = c_nvmlViolationTime_t()
fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus")
## Invoke the method to get violation time
# Invoke the method to get violation time
ret = fn(device, c_perfPolicy_type, byref(c_violTime))
_nvmlCheckReturn(ret)
return c_violTime

View File

@@ -148,4 +148,3 @@ elif os.name == 'posix': # pragma: no cover
else: # pragma: no cover
raise RuntimeError('PortaLocker only defined for nt and posix platforms')

View File

@@ -218,6 +218,7 @@ class RLock(Lock):
can be acquired multiple times. When the corresponding number of release()
calls are made the lock will finally release the underlying file lock.
"""
def __init__(
self, filename, mode='a', timeout=DEFAULT_TIMEOUT,
check_interval=DEFAULT_CHECK_INTERVAL, fail_when_locked=False,

View File

@@ -1 +0,0 @@

View File

@@ -56,7 +56,7 @@ def project_import_modules(project_path, ignores):
continue
# Hack detect if this is a virtual-env folder, if so add it to the uignore list
if set(dirnames) == venv_subdirs:
ignore_absolute.append(Path(dirpath).as_posix()+os.sep)
ignore_absolute.append(Path(dirpath).as_posix() + os.sep)
continue
py_files = list()
@@ -132,7 +132,7 @@ class ImportChecker(object):
level -= 1
mod_name = ''
for alias in node.names:
name = level*'.' + mod_name + '.' + alias.name
name = level * '.' + mod_name + '.' + alias.name
self._modules.add(name, self._fpath, node.lineno + self._lineno)
if try_:
self._try_imports.add(name)

View File

@@ -92,7 +92,7 @@ class Archive(object):
def is_safe(self, filename):
return not (filename.startswith(("/", "\\")) or
(len(filename) > 1 and filename[1] == ":" and
filename[0] in string.ascii_letter) or
filename[0] in string.ascii_letter) or
re.search(r"[.][.][/\\]", filename))
def __enter__(self):

View File

@@ -102,11 +102,11 @@ def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=
for s in series:
# if we need to down-sample, use low-pass average filter and sampling
if s.data.size >= base_size:
budget = int(leftover * s.data.size/(total_size-baseused_size))
budget = int(leftover * s.data.size / (total_size - baseused_size))
step = int(np.ceil(s.data.size / float(budget)))
x = s.data[:, 0][::-step][::-1]
y = s.data[:, 1]
y_low_pass = np.convolve(y, np.ones(shape=(step,), dtype=y.dtype)/float(step), mode='same')
y_low_pass = np.convolve(y, np.ones(shape=(step,), dtype=y.dtype) / float(step), mode='same')
y = y_low_pass[::-step][::-1]
s.data = np.array([x, y], dtype=s.data.dtype).T
@@ -186,7 +186,8 @@ def create_3d_scatter_series(np_row_wise, title="Scatter", series_name="Series",
:return:
"""
if not plotly_obj:
plotly_obj = plotly_scatter3d_layout_dict(title=title, xaxis_title=xtitle, yaxis_title=ytitle, zaxis_title=ztitle)
plotly_obj = plotly_scatter3d_layout_dict(
title=title, xaxis_title=xtitle, yaxis_title=ytitle, zaxis_title=ztitle)
assert np_row_wise.ndim == 2, "Expected a 2D numpy array"
assert np_row_wise.shape[1] == 3, "Expected three columns X/Y/Z e.g. [(x0,y0,z0), (x1,y1,z1) ...]"
@@ -282,7 +283,7 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels
"yaxis": {
"title": ytitle,
"showgrid": False,
"nticks": 10,
"nticks": 10,
"ticktext": ylabels,
"tickvals": list(range(len(ylabels))) if ylabels else ylabels,
},

View File

@@ -74,7 +74,7 @@ class ProxyDictPreWrite(dict):
return key_value
def _nested_callback(self, prefix, key_value):
return self._set_callback((prefix+'.'+key_value[0], key_value[1],))
return self._set_callback((prefix + '.' + key_value[0], key_value[1],))
def flatten_dictionary(a_dict, prefix=''):
@@ -84,15 +84,15 @@ def flatten_dictionary(a_dict, prefix=''):
for k, v in a_dict.items():
k = str(k)
if isinstance(v, (float, int, bool, six.string_types)):
flat_dict[prefix+k] = v
flat_dict[prefix + k] = v
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
flat_dict[prefix+k] = v
flat_dict[prefix + k] = v
elif isinstance(v, dict):
flat_dict.update(flatten_dictionary(v, prefix=prefix+k+sep))
flat_dict.update(flatten_dictionary(v, prefix=prefix + k + sep))
else:
# this is a mixture of list and dict, or any other object,
# leave it as is, we have nothing to do with it.
flat_dict[prefix+k] = v
flat_dict[prefix + k] = v
return flat_dict
@@ -102,15 +102,15 @@ def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
for k, v in a_dict.items():
k = str(k)
if isinstance(v, (float, int, bool, six.string_types)):
a_dict[k] = flat_dict.get(prefix+k, v)
a_dict[k] = flat_dict.get(prefix + k, v)
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
a_dict[k] = flat_dict.get(prefix+k, v)
a_dict[k] = flat_dict.get(prefix + k, v)
elif isinstance(v, dict):
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix+k+sep) or v
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep) or v
else:
# this is a mixture of list and dict, or any other object,
# leave it as is, we have nothing to do with it.
a_dict[k] = flat_dict.get(prefix+k, v)
a_dict[k] = flat_dict.get(prefix + k, v)
return a_dict
@@ -123,7 +123,7 @@ def naive_nested_from_flat_dictionary(flat_dict, sep='/'):
bucket[0][1] if (len(bucket) == 1 and sub_prefix == bucket[0][0])
else naive_nested_from_flat_dictionary(
{
k[len(sub_prefix)+1:]: v
k[len(sub_prefix) + 1:]: v
for k, v in bucket
if len(k) > len(sub_prefix)
}

View File

@@ -366,7 +366,7 @@ class ConfigParser(object):
null_expr = Keyword("null", caseless=True).setParseAction(replaceWith(NoneValue()))
# key = QuotedString('"', escChar='\\', unquoteResults=False) | Word(alphanums + alphas8bit + '._- /')
key = QuotedString('"', escChar='\\', unquoteResults=False) | \
Word("0123456789.").setParseAction(safe_convert_number) | Word(alphanums + alphas8bit + '._- /')
Word("0123456789.").setParseAction(safe_convert_number) | Word(alphanums + alphas8bit + '._- /')
eol = Word('\n\r').suppress()
eol_comma = Word('\n\r,').suppress()
@@ -390,13 +390,15 @@ class ConfigParser(object):
# line1 \
# line2 \
# so a backslash precedes the \n
unquoted_string = Regex(r'(?:[^^`+?!@*&"\[\{\s\]\}#,=\$\\]|\\.)+[ \t]*', re.UNICODE).setParseAction(unescape_string)
unquoted_string = Regex(r'(?:[^^`+?!@*&"\[\{\s\]\}#,=\$\\]|\\.)+[ \t]*',
re.UNICODE).setParseAction(unescape_string)
substitution_expr = Regex(r'[ \t]*\$\{[^\}]+\}[ \t]*').setParseAction(create_substitution)
string_expr = multiline_string | quoted_string | unquoted_string
value_expr = period_expr | number_expr | true_expr | false_expr | null_expr | string_expr
include_content = (quoted_string | ((Keyword('url') | Keyword('file')) - Literal('(').suppress() - quoted_string - Literal(')').suppress()))
include_content = (quoted_string | ((Keyword('url') | Keyword(
'file')) - Literal('(').suppress() - quoted_string - Literal(')').suppress()))
include_expr = (
Keyword("include", caseless=True).suppress() + (
include_content | (
@@ -408,33 +410,34 @@ class ConfigParser(object):
root_dict_expr = Forward()
dict_expr = Forward()
list_expr = Forward()
multi_value_expr = ZeroOrMore(comment_eol | include_expr | substitution_expr | dict_expr | list_expr | value_expr | (Literal(
'\\') - eol).suppress())
multi_value_expr = ZeroOrMore(comment_eol | include_expr | substitution_expr |
dict_expr | list_expr | value_expr | (Literal('\\') - eol).suppress())
# for a dictionary : or = is optional
# last zeroOrMore is because we can have t = {a:4} {b: 6} {c: 7} which is dictionary concatenation
inside_dict_expr = ConfigTreeParser(ZeroOrMore(comment_eol | include_expr | assign_expr | eol_comma))
inside_root_dict_expr = ConfigTreeParser(ZeroOrMore(comment_eol | include_expr | assign_expr | eol_comma), root=True)
inside_root_dict_expr = ConfigTreeParser(ZeroOrMore(
comment_eol | include_expr | assign_expr | eol_comma), root=True)
dict_expr << Suppress('{') - inside_dict_expr - Suppress('}')
root_dict_expr << Suppress('{') - inside_root_dict_expr - Suppress('}')
list_entry = ConcatenatedValueParser(multi_value_expr)
list_expr << Suppress('[') - ListParser(list_entry - ZeroOrMore(eol_comma - list_entry)) - Suppress(']')
# special case when we have a value assignment where the string can potentially be the remainder of the line
assign_expr << Group(
key - ZeroOrMore(comment_no_comma_eol) - (dict_expr | (Literal('=') | Literal(':') | Literal('+=')) - ZeroOrMore(
comment_no_comma_eol) - ConcatenatedValueParser(multi_value_expr))
)
assign_expr << Group(key - ZeroOrMore(comment_no_comma_eol) -
(dict_expr | (Literal('=') | Literal(':') | Literal('+=')) -
ZeroOrMore(comment_no_comma_eol) - ConcatenatedValueParser(multi_value_expr)))
# the file can be { ... } where {} can be omitted or []
config_expr = ZeroOrMore(comment_eol | eol) + (list_expr | root_dict_expr | inside_root_dict_expr) + ZeroOrMore(
comment_eol | eol_comma)
config_expr = ZeroOrMore(comment_eol | eol) + (list_expr | root_dict_expr |
inside_root_dict_expr) + ZeroOrMore(comment_eol | eol_comma)
config = config_expr.parseString(content, parseAll=True)[0]
if resolve:
allow_unresolved = resolve and unresolved_value is not DEFAULT_SUBSTITUTION and unresolved_value is not MANDATORY_SUBSTITUTION
has_unresolved = cls.resolve_substitutions(config, allow_unresolved)
if has_unresolved and unresolved_value is MANDATORY_SUBSTITUTION:
raise ConfigSubstitutionException('resolve cannot be set to True and unresolved_value to MANDATORY_SUBSTITUTION')
raise ConfigSubstitutionException(
'resolve cannot be set to True and unresolved_value to MANDATORY_SUBSTITUTION')
if unresolved_value is not NO_SUBSTITUTION and unresolved_value is not DEFAULT_SUBSTITUTION:
cls.unresolve_substitutions_to_value(config, unresolved_value)
@@ -485,14 +488,16 @@ class ConfigParser(object):
if len(prop_path) > 1 and config.get(substitution.variable, None) is not None:
continue # If value is present in latest version, don't do anything
if prop_path[0] == key:
if isinstance(previous_item, ConfigValues) and not accept_unresolved: # We hit a dead end, we cannot evaluate
if isinstance(
previous_item, ConfigValues) and not accept_unresolved: # We hit a dead end, we cannot evaluate
raise ConfigSubstitutionException(
"Property {variable} cannot be substituted. Check for cycles.".format(
variable=substitution.variable
)
)
else:
value = previous_item if len(prop_path) == 1 else previous_item.get(".".join(prop_path[1:]))
value = previous_item if len(
prop_path) == 1 else previous_item.get(".".join(prop_path[1:]))
_, _, current_item = cls._do_substitute(substitution, value)
previous_item = current_item
@@ -632,7 +637,8 @@ class ConfigParser(object):
# self resolution, backtrack
resolved_value = substitution.parent.overriden_value
unresolved, new_substitutions, result = cls._do_substitute(substitution, resolved_value, is_optional_resolved)
unresolved, new_substitutions, result = cls._do_substitute(
substitution, resolved_value, is_optional_resolved)
any_unresolved = unresolved or any_unresolved
substitutions.extend(new_substitutions)
if not isinstance(result, ConfigValues):

View File

@@ -148,7 +148,8 @@ class ConfigTree(OrderedDict):
if elt is UndefinedKey:
if default is UndefinedKey:
raise ConfigMissingException(u"No configuration setting found for key {key}".format(key='.'.join(key_path[:key_index + 1])))
raise ConfigMissingException(u"No configuration setting found for key {key}".format(
key='.'.join(key_path[: key_index + 1])))
else:
return default
@@ -184,7 +185,9 @@ class ConfigTree(OrderedDict):
return [string]
special_characters = '$}[]:=+#`^?!@*&.'
tokens = re.findall(r'"[^"]+"|[^{special_characters}]+'.format(special_characters=re.escape(special_characters)), string)
tokens = re.findall(
r'"[^"]+"|[^{special_characters}]+'.format(special_characters=re.escape(special_characters)),
string)
def contains_special_character(token):
return any((c in special_characters) for c in token)

View File

@@ -20,7 +20,7 @@ except NameError:
class HOCONConverter(object):
_number_re = r'[+-]?(\d*\.\d+|\d+(\.\d+)?)([eE][+\-]?\d+)?(?=$|[ \t]*([\$\}\],#\n\r]|//))'
_number_re_matcher = re.compile(_number_re)
@classmethod
def to_json(cls, config, compact=False, indent=2, level=0):
"""Convert HOCON input into a JSON output
@@ -150,7 +150,7 @@ class HOCONConverter(object):
new_line = cls.to_hocon(item, compact, indent, level + 1)
lines += new_line
if '\n' in new_line or len(lines) - base_len > 80:
if i < len(config)-1:
if i < len(config) - 1:
lines += ',\n{indent}'.format(indent=''.rjust(level * indent, ' '))
base_len = len(lines)
skip_comma = True

View File

@@ -45,7 +45,7 @@ class ResourceMonitor(object):
else: # if running_remotely():
try:
active_gpus = os.environ.get('NVIDIA_VISIBLE_DEVICES', '') or \
os.environ.get('CUDA_VISIBLE_DEVICES', '')
os.environ.get('CUDA_VISIBLE_DEVICES', '')
if active_gpus:
self._active_gpus = [int(g.strip()) for g in active_gpus.split(',')]
except Exception:
@@ -150,7 +150,7 @@ class ResourceMonitor(object):
try:
title = self._title_gpu if k.startswith('gpu_') else self._title_machine
# 3 points after the dot
value = round(v*1000) / 1000.
value = round(v * 1000) / 1000.
self._task.get_logger().report_scalar(title=title, series=k, iteration=iteration, value=value)
except Exception:
pass
@@ -178,7 +178,7 @@ class ResourceMonitor(object):
return self._num_readouts
def _get_average_readouts(self):
average_readouts = dict((k, v/float(self._num_readouts)) for k, v in self._readouts.items())
average_readouts = dict((k, v / float(self._num_readouts)) for k, v in self._readouts.items())
return average_readouts
def _clear_readouts(self):
@@ -205,7 +205,7 @@ class ResourceMonitor(object):
self._get_process_used_memory() if self._process_info else virtual_memory.used) / 1024
stats["memory_free_gb"] = bytes_to_megabytes(virtual_memory.available) / 1024
disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
stats["disk_free_percent"] = 100.0-disk_use_percentage
stats["disk_free_percent"] = 100.0 - disk_use_percentage
with warnings.catch_warnings():
if logging.root.level > logging.DEBUG: # If the logging level is bigger than debug, ignore
# psutil.sensors_temperatures warnings

View File

@@ -6,6 +6,7 @@ try:
except Exception:
np = None
def make_deterministic(seed=1337, cudnn_deterministic=False):
"""
Ensure deterministic behavior across PyTorch using the provided random seed.