mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix documentation and layout (PEP8)
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from time import sleep, time
|
||||
from ..parameters import DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..optimization import Objective, SearchStrategy
|
||||
from ..parameters import (
|
||||
DiscreteParameterRange, UniformParameterRange, RandomSeed, UniformIntegerParameterRange, Parameter, )
|
||||
from ...task import Task
|
||||
|
||||
try:
|
||||
@@ -9,16 +12,27 @@ try:
|
||||
import hpbandster.core.nameserver as hpns
|
||||
import ConfigSpace as CS
|
||||
import ConfigSpace.hyperparameters as CSH
|
||||
|
||||
Task.add_requirements('hpbandster')
|
||||
except ImportError:
|
||||
raise ValueError("OptimizerBOHB requires 'hpbandster' package, it was not found\n"
|
||||
"install with: pip install hpbandster")
|
||||
|
||||
|
||||
class TrainsBandsterWorker(Worker):
|
||||
def __init__(self, *args, optimizer, base_task_id, queue_name, objective,
|
||||
sleep_interval=0, budget_iteration_scale=1., **kwargs):
|
||||
super(TrainsBandsterWorker, self).__init__(*args, **kwargs)
|
||||
class _TrainsBandsterWorker(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
*args, # type: Any
|
||||
optimizer, # type: OptimizerBOHB
|
||||
base_task_id, # type: str
|
||||
queue_name, # type: str
|
||||
objective, # type: Objective
|
||||
sleep_interval=0, # type: float
|
||||
budget_iteration_scale=1., # type: float
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> _TrainsBandsterWorker
|
||||
super(_TrainsBandsterWorker, self).__init__(*args, **kwargs)
|
||||
self.optimizer = optimizer
|
||||
self.base_task_id = base_task_id
|
||||
self.queue_name = queue_name
|
||||
@@ -28,6 +42,7 @@ class TrainsBandsterWorker(Worker):
|
||||
self._current_job = None
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
# type: (dict, float, Any) -> dict
|
||||
"""
|
||||
Simple example for a compute function
|
||||
The loss is just a the config + some noise (that decreases with the budget)
|
||||
@@ -43,10 +58,12 @@ class TrainsBandsterWorker(Worker):
|
||||
'info' (dict)
|
||||
"""
|
||||
self._current_job = self.optimizer.helper_create_job(self.base_task_id, parameter_override=config)
|
||||
# noinspection PyProtectedMember
|
||||
self.optimizer._current_jobs.append(self._current_job)
|
||||
self._current_job.launch(self.queue_name)
|
||||
iteration_value = None
|
||||
while not self._current_job.is_stopped():
|
||||
# noinspection PyProtectedMember
|
||||
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(self._current_job)
|
||||
if iteration_value and iteration_value[0] >= self.budget_iteration_scale * budget:
|
||||
self._current_job.abort()
|
||||
@@ -56,33 +73,65 @@ class TrainsBandsterWorker(Worker):
|
||||
result = {
|
||||
# this is the a mandatory field to run hyperband
|
||||
# remember: HpBandSter always minimizes!
|
||||
'loss': float(self.objective.get_normalized_objective(self._current_job)*-1.),
|
||||
'loss': float(self.objective.get_normalized_objective(self._current_job) * -1.),
|
||||
# can be used for any user-defined information - also mandatory
|
||||
'info': self._current_job.task_id()
|
||||
}
|
||||
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
|
||||
# noinspection PyProtectedMember
|
||||
self.optimizer._current_jobs.remove(self._current_job)
|
||||
return result
|
||||
|
||||
|
||||
class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
def __init__(self, base_task_id, hyper_parameters, objective_metric,
|
||||
execution_queue, num_concurrent_workers, min_iteration_per_job, max_iteration_per_job, total_max_jobs,
|
||||
pool_period_min=2.0, max_job_execution_minutes=None, **bohb_kargs):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric, # type: Objective
|
||||
execution_queue, # type: str
|
||||
num_concurrent_workers, # type: int
|
||||
total_max_jobs, # type: Optional[int]
|
||||
min_iteration_per_job, # type: Optional[int]
|
||||
max_iteration_per_job, # type: Optional[int]
|
||||
pool_period_min=2., # type: float
|
||||
max_job_execution_minutes=None, # type: Optional[float]
|
||||
**bohb_kwargs, # type: Any
|
||||
):
|
||||
# type: (...) -> OptimizerBOHB
|
||||
"""
|
||||
Initialize a search strategy optimizer
|
||||
Initialize a BOHB search strategy optimizer
|
||||
BOHB performs robust and efficient hyperparameter optimization at scale by combining
|
||||
the speed of Hyperband searches with the guidance and guarantees of convergence of Bayesian
|
||||
Optimization. Instead of sampling new configurations at random,
|
||||
BOHB uses kernel density estimators to select promising candidates.
|
||||
|
||||
For reference: ::
|
||||
|
||||
@InProceedings{falkner-icml-18,
|
||||
title = {{BOHB}: Robust and Efficient Hyperparameter Optimization at Scale},
|
||||
author = {Falkner, Stefan and Klein, Aaron and Hutter, Frank},
|
||||
booktitle = {Proceedings of the 35th International Conference on Machine Learning},
|
||||
pages = {1436--1445},
|
||||
year = {2018},
|
||||
}
|
||||
|
||||
:param str base_task_id: Task ID (str)
|
||||
:param list hyper_parameters: list of Parameter objects to optimize over
|
||||
:param Objective objective_metric: Objective metric to maximize / minimize
|
||||
:param str execution_queue: execution queue to use for launching Tasks (experiments).
|
||||
:param int num_concurrent_workers: Limit number of concurrent running machines
|
||||
:param float min_iteration_per_job: minimum number of iterations for a job to run.
|
||||
:param int num_concurrent_workers: Limit number of concurrent running Tasks (machines)
|
||||
:param int min_iteration_per_job: minimum number of iterations for a job to run.
|
||||
'iterations' are the reported iterations for the specified objective,
|
||||
not the maximum reported iteration of the Task.
|
||||
:param int max_iteration_per_job: number of iteration per job
|
||||
:param int total_max_jobs: total maximum job for the optimization process. Default None, unlimited
|
||||
'iterations' are the reported iterations for the specified objective,
|
||||
not the maximum reported iteration of the Task.
|
||||
:param int total_max_jobs: total maximum job for the optimization process.
|
||||
Must be provided in order to calculate the total budget for the optimization process.
|
||||
:param float pool_period_min: time in minutes between two consecutive pools
|
||||
:param float max_job_execution_minutes: maximum time per single job in minutes, if exceeded job is aborted
|
||||
:param ** bohb_kargs: arguments passed directly yo the BOHB object
|
||||
:param bohb_kwargs: arguments passed directly yo the BOHB object
|
||||
"""
|
||||
super(OptimizerBOHB, self).__init__(
|
||||
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
|
||||
@@ -91,16 +140,25 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
total_max_jobs=total_max_jobs)
|
||||
self._max_iteration_per_job = max_iteration_per_job
|
||||
self._min_iteration_per_job = min_iteration_per_job
|
||||
self._bohb_kwargs = bohb_kargs or {}
|
||||
self._bohb_kwargs = bohb_kwargs or {}
|
||||
self._param_iterator = None
|
||||
self._namespace = None
|
||||
self._bohb = None
|
||||
self._res = None
|
||||
|
||||
def set_optimization_args(self, eta=3, min_budget=None, max_budget=None,
|
||||
min_points_in_model=None, top_n_percent=15,
|
||||
num_samples=None, random_fraction=1/3., bandwidth_factor=3,
|
||||
min_bandwidth=1e-3):
|
||||
def set_optimization_args(
|
||||
self,
|
||||
eta=3, # type: float
|
||||
min_budget=None, # type: Optional[float]
|
||||
max_budget=None, # type: Optional[float]
|
||||
min_points_in_model=None, # type: Optional[int]
|
||||
top_n_percent=15, # type: Optional[int]
|
||||
num_samples=None, # type: Optional[int]
|
||||
random_fraction=1 / 3., # type: Optional[float]
|
||||
bandwidth_factor=3, # type: Optional[float]
|
||||
min_bandwidth=1e-3, # type: Optional[float]
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
Defaults copied from BOHB constructor, see details in BOHB.__init__
|
||||
|
||||
@@ -134,7 +192,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
max_budget : float (1)
|
||||
The largest budget to consider. Needs to be larger than min_budget!
|
||||
The budgets will be geometrically distributed
|
||||
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
|
||||
:math:`a^2 + b^2 = c^2 /sim /eta^k` for :math:`k/in [0, 1, ... , num/_subsets - 1]`.
|
||||
min_points_in_model: int (None)
|
||||
number of observations to start building a KDE. Default 'None' means
|
||||
dim+1, the bare minimum.
|
||||
@@ -166,9 +224,16 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
self._bohb_kwargs['min_bandwidth'] = min_bandwidth
|
||||
|
||||
def start(self):
|
||||
# Step 1: Start a nameserver
|
||||
# type: () -> ()
|
||||
"""
|
||||
Start the Optimizer controller function loop()
|
||||
If the calling process is stopped, the controller will stop as well.
|
||||
|
||||
Notice: This function returns only after optimization is completed! or stop() was called.
|
||||
"""
|
||||
# Step 1: Start a NameServer
|
||||
fake_run_id = 'OptimizerBOHB_{}'.format(time())
|
||||
self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=None)
|
||||
self._namespace = hpns.NameServer(run_id=fake_run_id, host='127.0.0.1', port=0)
|
||||
self._namespace.start()
|
||||
|
||||
# we have to scale the budget to the iterations per job, otherwise numbers might be too high
|
||||
@@ -177,21 +242,22 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
# Step 2: Start the workers
|
||||
workers = []
|
||||
for i in range(self._num_concurrent_workers):
|
||||
w = TrainsBandsterWorker(optimizer=self,
|
||||
sleep_interval=int(self.pool_period_minutes*60),
|
||||
budget_iteration_scale=budget_iteration_scale,
|
||||
base_task_id=self._base_task_id,
|
||||
objective=self._objective_metric,
|
||||
queue_name=self._execution_queue,
|
||||
nameserver='127.0.0.1', run_id=fake_run_id, id=i)
|
||||
w = _TrainsBandsterWorker(
|
||||
optimizer=self,
|
||||
sleep_interval=int(self.pool_period_minutes * 60),
|
||||
budget_iteration_scale=budget_iteration_scale,
|
||||
base_task_id=self._base_task_id,
|
||||
objective=self._objective_metric,
|
||||
queue_name=self._execution_queue,
|
||||
nameserver='127.0.0.1', run_id=fake_run_id, id=i)
|
||||
w.run(background=True)
|
||||
workers.append(w)
|
||||
|
||||
# Step 3: Run an optimizer
|
||||
self._bohb = BOHB(configspace=self.convert_hyper_parameters_to_cs(),
|
||||
self._bohb = BOHB(configspace=self._convert_hyper_parameters_to_cs(),
|
||||
run_id=fake_run_id,
|
||||
num_samples=self.total_max_jobs,
|
||||
min_budget=float(self._min_iteration_per_job)/float(self._max_iteration_per_job),
|
||||
min_budget=float(self._min_iteration_per_job) / float(self._max_iteration_per_job),
|
||||
**self._bohb_kwargs)
|
||||
self._res = self._bohb.run(n_iterations=self.total_max_jobs, min_n_workers=self._num_concurrent_workers)
|
||||
|
||||
@@ -199,6 +265,11 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Stop the current running optimization loop,
|
||||
Called from a different thread than the start()
|
||||
"""
|
||||
# After the optimizer run, we must shutdown the master and the nameserver.
|
||||
self._bohb.shutdown(shutdown_workers=True)
|
||||
self._namespace.shutdown()
|
||||
@@ -216,13 +287,14 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
print('A total of {} unique configurations where sampled.'.format(len(id2config.keys())))
|
||||
print('A total of {} runs where executed.'.format(len(self._res.get_all_runs())))
|
||||
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
|
||||
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
|
||||
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
|
||||
print('Total budget corresponds to {:.1f} full function evaluations.'.format(
|
||||
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
|
||||
sum([r.budget for r in all_runs]) / self._bohb_kwargs.get('max_budget', 1.0)))
|
||||
print('The run took {:.1f} seconds to complete.'.format(
|
||||
all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))
|
||||
all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))
|
||||
|
||||
def convert_hyper_parameters_to_cs(self):
|
||||
def _convert_hyper_parameters_to_cs(self):
|
||||
# type: () -> CS.ConfigurationSpace
|
||||
cs = CS.ConfigurationSpace(seed=self._seed)
|
||||
for p in self._hyper_parameters:
|
||||
if isinstance(p, UniformParameterRange):
|
||||
|
||||
@@ -2,16 +2,26 @@ import hashlib
|
||||
from datetime import datetime
|
||||
from logging import getLogger
|
||||
from time import time, sleep
|
||||
from typing import Optional, Mapping, Sequence, Any
|
||||
|
||||
from ..task import Task
|
||||
from ..backend_api.services import tasks as tasks_service
|
||||
from ..backend_api.services import events as events_service
|
||||
|
||||
|
||||
logger = getLogger('trains.automation.job')
|
||||
|
||||
|
||||
class TrainsJob(object):
|
||||
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, parent=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
parameter_override=None, # type: Optional[Mapping[str, str]]
|
||||
task_overrides=None, # type: Optional[Mapping[str, str]]
|
||||
tags=None, # type: Optional[Sequence[str]]
|
||||
parent=None, # type: Optional[str]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> TrainsJob
|
||||
"""
|
||||
Create a new Task based in a base_task_id with a different set of parameters
|
||||
|
||||
@@ -32,11 +42,12 @@ class TrainsJob(object):
|
||||
if task_overrides:
|
||||
# todo: make sure it works
|
||||
# noinspection PyProtectedMember
|
||||
self.task._edit(task_overrides)
|
||||
self.task._edit(**task_overrides)
|
||||
self.task_started = False
|
||||
self._worker = None
|
||||
|
||||
def get_metric(self, title, series):
|
||||
# type: (str, str) -> (float, float, float)
|
||||
"""
|
||||
Retrieve a specific scalar metric from the running Task.
|
||||
|
||||
@@ -63,6 +74,7 @@ class TrainsJob(object):
|
||||
return tuple(response.response_data['tasks'][0]['last_metrics'][title][series][v] for v in values)
|
||||
|
||||
def launch(self, queue_name=None):
|
||||
# type: (str) -> ()
|
||||
"""
|
||||
Send Job for execution on the requested execution queue
|
||||
|
||||
@@ -74,6 +86,7 @@ class TrainsJob(object):
|
||||
logger.warning(ex)
|
||||
|
||||
def abort(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Abort currently running job (can be called multiple times)
|
||||
"""
|
||||
@@ -83,6 +96,7 @@ class TrainsJob(object):
|
||||
logger.warning(ex)
|
||||
|
||||
def elapsed(self):
|
||||
# type: () -> float
|
||||
"""
|
||||
Return the time in seconds since job started. Return -1 if job is still pending
|
||||
|
||||
@@ -94,6 +108,7 @@ class TrainsJob(object):
|
||||
return (datetime.now() - self.task.data.started).timestamp()
|
||||
|
||||
def iterations(self):
|
||||
# type: () -> int
|
||||
"""
|
||||
Return the last iteration value of the current job. -1 if job has not started yet
|
||||
|
||||
@@ -105,6 +120,7 @@ class TrainsJob(object):
|
||||
return self.task.get_last_iteration()
|
||||
|
||||
def task_id(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return the Task id.
|
||||
|
||||
@@ -113,6 +129,7 @@ class TrainsJob(object):
|
||||
return self.task.id
|
||||
|
||||
def status(self):
|
||||
# type: () -> Task.TaskStatusEnum
|
||||
"""
|
||||
Return the Job Task current status, see Task.TaskStatusEnum
|
||||
|
||||
@@ -121,6 +138,7 @@ class TrainsJob(object):
|
||||
return self.task.status
|
||||
|
||||
def wait(self, timeout=None, pool_period=30.):
|
||||
# type: (Optional[float], float) -> bool
|
||||
"""
|
||||
Wait until the task is fully executed (i.e. aborted/completed/failed)
|
||||
|
||||
@@ -129,7 +147,7 @@ class TrainsJob(object):
|
||||
:return bool: Return True is Task finished.
|
||||
"""
|
||||
tic = time()
|
||||
while timeout is None or time()-tic < timeout*60.:
|
||||
while timeout is None or time() - tic < timeout * 60.:
|
||||
if self.is_stopped():
|
||||
return True
|
||||
sleep(pool_period)
|
||||
@@ -137,6 +155,7 @@ class TrainsJob(object):
|
||||
return self.is_stopped()
|
||||
|
||||
def get_console_output(self, number_of_reports=1):
|
||||
# type: (int) -> Sequence[str]
|
||||
"""
|
||||
Return a list of console outputs reported by the Task.
|
||||
Returned console outputs are retrieved from the most updated console outputs.
|
||||
@@ -147,6 +166,7 @@ class TrainsJob(object):
|
||||
return self.task.get_reported_console_output(number_of_reports=number_of_reports)
|
||||
|
||||
def worker(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return the current worker id executing this Job. If job is pending, returns None
|
||||
|
||||
@@ -165,6 +185,7 @@ class TrainsJob(object):
|
||||
return self._worker
|
||||
|
||||
def is_running(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if job is currently running (pending is considered False)
|
||||
|
||||
@@ -173,6 +194,7 @@ class TrainsJob(object):
|
||||
return self.task.status == Task.TaskStatusEnum.in_progress
|
||||
|
||||
def is_stopped(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if job is has executed and is not any more
|
||||
|
||||
@@ -183,6 +205,7 @@ class TrainsJob(object):
|
||||
Task.TaskStatusEnum.failed, Task.TaskStatusEnum.published)
|
||||
|
||||
def is_pending(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if job is waiting for execution
|
||||
|
||||
@@ -191,8 +214,20 @@ class TrainsJob(object):
|
||||
return self.task.status in (Task.TaskStatusEnum.queued, Task.TaskStatusEnum.created)
|
||||
|
||||
|
||||
class JobStub(object):
|
||||
def __init__(self, base_task_id, parameter_override=None, task_overrides=None, tags=None, **_):
|
||||
# noinspection PyMethodMayBeStatic, PyUnusedLocal
|
||||
class _JobStub(object):
|
||||
"""
|
||||
This is a Job Stub, use only for debugging
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
parameter_override=None, # type: Optional[Mapping[str, str]]
|
||||
task_overrides=None, # type: Optional[Mapping[str, str]]
|
||||
tags=None, # type: Optional[Sequence[str]]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> _JobStub
|
||||
self.task = None
|
||||
self.base_task_id = base_task_id
|
||||
self.parameter_override = parameter_override
|
||||
@@ -202,14 +237,17 @@ class JobStub(object):
|
||||
self.task_started = None
|
||||
|
||||
def launch(self, queue_name=None):
|
||||
# type: (str) -> ()
|
||||
self.iteration = 0
|
||||
self.task_started = time()
|
||||
print('launching', self.parameter_override, 'in', queue_name)
|
||||
|
||||
def abort(self):
|
||||
# type: () -> ()
|
||||
self.task_started = -1
|
||||
|
||||
def elapsed(self):
|
||||
# type: () -> float
|
||||
"""
|
||||
Return the time in seconds since job started. Return -1 if job is still pending
|
||||
|
||||
@@ -220,6 +258,7 @@ class JobStub(object):
|
||||
return time() - self.task_started
|
||||
|
||||
def iterations(self):
|
||||
# type: () -> int
|
||||
"""
|
||||
Return the last iteration value of the current job. -1 if job has not started yet
|
||||
:return int: Task last iteration
|
||||
@@ -229,6 +268,7 @@ class JobStub(object):
|
||||
return self.iteration
|
||||
|
||||
def get_metric(self, title, series):
|
||||
# type: (str, str) -> (float, float, float)
|
||||
"""
|
||||
Retrieve a specific scalar metric from the running Task.
|
||||
|
||||
@@ -239,15 +279,19 @@ class JobStub(object):
|
||||
return 0, 1.0, 0.123
|
||||
|
||||
def task_id(self):
|
||||
# type: () -> str
|
||||
return 'stub'
|
||||
|
||||
def worker(self):
|
||||
# type: () -> ()
|
||||
return None
|
||||
|
||||
def status(self):
|
||||
# type: () -> str
|
||||
return 'in_progress'
|
||||
|
||||
def wait(self, timeout=None, pool_period=30.):
|
||||
# type: (Optional[float], float) -> bool
|
||||
"""
|
||||
Wait for the task to be processed (i.e. aborted/completed/failed)
|
||||
|
||||
@@ -258,6 +302,7 @@ class JobStub(object):
|
||||
return True
|
||||
|
||||
def get_console_output(self, number_of_reports=1):
|
||||
# type: (int) -> Sequence[str]
|
||||
"""
|
||||
Return a list of console outputs reported by the Task.
|
||||
Returned console outputs are retrieved from the most updated console outputs.
|
||||
@@ -269,11 +314,13 @@ class JobStub(object):
|
||||
return []
|
||||
|
||||
def is_running(self):
|
||||
# type: () -> bool
|
||||
return self.task_started is not None and self.task_started > 0
|
||||
|
||||
def is_stopped(self):
|
||||
# type: () -> bool
|
||||
return self.task_started is not None and self.task_started < 0
|
||||
|
||||
def is_pending(self):
|
||||
# type: () -> bool
|
||||
return self.task_started is None
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ from itertools import product
|
||||
from logging import getLogger
|
||||
from threading import Thread, Event
|
||||
from time import time
|
||||
from typing import Union, Any, Sequence, Optional, Mapping, Callable
|
||||
|
||||
from .job import TrainsJob
|
||||
from ..task import Task
|
||||
from .parameters import Parameter
|
||||
from ..task import Task
|
||||
|
||||
logger = getLogger('trains.automation.optimization')
|
||||
|
||||
@@ -32,6 +33,7 @@ class Objective(object):
|
||||
"""
|
||||
|
||||
def __init__(self, title, series, order='max', extremum=False):
|
||||
# type: (str, str, Union['max', 'min'], bool) -> Objective
|
||||
"""
|
||||
Construct objective object that will return the scalar value for a specific task ID
|
||||
|
||||
@@ -45,12 +47,12 @@ class Objective(object):
|
||||
self.series = series
|
||||
assert order in ('min', 'max',)
|
||||
# normalize value so we always look for the highest objective value
|
||||
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') or \
|
||||
(not isinstance(order, str) and order < 0) else +1
|
||||
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else +1
|
||||
self._metric = None
|
||||
self.extremum = extremum
|
||||
|
||||
def get_objective(self, task_id):
|
||||
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
|
||||
"""
|
||||
Return a specific task scalar value based on the objective settings (title/series)
|
||||
|
||||
@@ -58,7 +60,7 @@ class Objective(object):
|
||||
:return float: scalar value
|
||||
"""
|
||||
# create self._metric
|
||||
self.get_last_metrics_encode_field()
|
||||
self._get_last_metrics_encode_field()
|
||||
|
||||
if isinstance(task_id, Task):
|
||||
task_id = task_id.id
|
||||
@@ -85,6 +87,7 @@ class Objective(object):
|
||||
return None
|
||||
|
||||
def get_current_raw_objective(self, task):
|
||||
# type: (Union[TrainsJob, Task]) -> (int, float)
|
||||
"""
|
||||
Return the current raw value (without sign normalization) of the objective
|
||||
|
||||
@@ -103,12 +106,14 @@ class Objective(object):
|
||||
# todo: replace with more efficient code
|
||||
scalars = task.get_reported_scalars()
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return scalars[self.title][self.series]['x'][-1], scalars[self.title][self.series]['y'][-1]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_objective_sign(self):
|
||||
# type: () -> float
|
||||
"""
|
||||
Return the sign of the objective (i.e. +1 if maximizing, and -1 if minimizing)
|
||||
|
||||
@@ -117,6 +122,7 @@ class Objective(object):
|
||||
return self.sign
|
||||
|
||||
def get_objective_metric(self):
|
||||
# type: () -> (str, str)
|
||||
"""
|
||||
Return the metric title, series pair of the objective
|
||||
|
||||
@@ -125,6 +131,7 @@ class Objective(object):
|
||||
return self.title, self.series
|
||||
|
||||
def get_normalized_objective(self, task_id):
|
||||
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
|
||||
"""
|
||||
Return a normalized task scalar value based on the objective settings (title/series)
|
||||
I.e. objective is always to maximize the returned value
|
||||
@@ -138,7 +145,13 @@ class Objective(object):
|
||||
# normalize value so we always look for the highest objective value
|
||||
return self.sign * objective
|
||||
|
||||
def get_last_metrics_encode_field(self):
|
||||
def _get_last_metrics_encode_field(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return encoded representation of title/series metric
|
||||
|
||||
:return str: string representing the objective title/series
|
||||
"""
|
||||
if not self._metric:
|
||||
title = hashlib.md5(str(self.title).encode('utf-8')).hexdigest()
|
||||
series = hashlib.md5(str(self.series).encode('utf-8')).hexdigest()
|
||||
@@ -153,11 +166,21 @@ class SearchStrategy(object):
|
||||
Base Search strategy class, inherit to implement your custom strategy
|
||||
"""
|
||||
_tag = 'optimization'
|
||||
_job_class = TrainsJob
|
||||
_job_class = TrainsJob # type: TrainsJob
|
||||
|
||||
def __init__(self, base_task_id, hyper_parameters, objective_metric,
|
||||
execution_queue, num_concurrent_workers, pool_period_min=2.,
|
||||
max_job_execution_minutes=None, total_max_jobs=None, **_):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric, # type: Objective
|
||||
execution_queue, # type: str
|
||||
num_concurrent_workers, # type: int
|
||||
pool_period_min=2., # type: float
|
||||
max_job_execution_minutes=None, # type: Optional[float]
|
||||
total_max_jobs=None, # type: Optional[int]
|
||||
**_ # type: Any
|
||||
):
|
||||
# type: (...) -> SearchStrategy
|
||||
"""
|
||||
Initialize a search strategy optimizer
|
||||
|
||||
@@ -189,6 +212,7 @@ class SearchStrategy(object):
|
||||
self._validate_base_task()
|
||||
|
||||
def start(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Start the Optimizer controller function loop()
|
||||
If the calling process is stopped, the controller will stop as well.
|
||||
@@ -205,6 +229,7 @@ class SearchStrategy(object):
|
||||
counter += 1
|
||||
|
||||
def stop(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Stop the current running optimization loop,
|
||||
Called from a different thread than the start()
|
||||
@@ -212,6 +237,7 @@ class SearchStrategy(object):
|
||||
self._stop_event.set()
|
||||
|
||||
def process_step(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Abstract helper function, not a must to implement, default use in start default implementation
|
||||
Main optimization loop, called from the daemon thread created by start()
|
||||
@@ -248,6 +274,7 @@ class SearchStrategy(object):
|
||||
return bool(self._current_jobs)
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
"""
|
||||
Abstract helper function, not a must to implement, default use in process_step default implementation
|
||||
Create a new job if needed. return the newly created job.
|
||||
@@ -258,17 +285,19 @@ class SearchStrategy(object):
|
||||
return None
|
||||
|
||||
def monitor_job(self, job):
|
||||
# type: (TrainsJob) -> bool
|
||||
"""
|
||||
Abstract helper function, not a must to implement, default use in process_step default implementation
|
||||
Check if the job needs to be aborted or already completed
|
||||
if return False, the job was aborted / completed, and should be taken off the current job list
|
||||
|
||||
:param TrainsJob job: a TrainsJob object to monitor
|
||||
:return: boolean, If False, job is no longer relevant
|
||||
:return bool: If False, job is no longer relevant
|
||||
"""
|
||||
return not job.is_stopped()
|
||||
|
||||
def get_running_jobs(self):
|
||||
# type: () -> Sequence[TrainsJob]
|
||||
"""
|
||||
Return the current running TrainsJobs
|
||||
|
||||
@@ -277,28 +306,32 @@ class SearchStrategy(object):
|
||||
return self._current_jobs
|
||||
|
||||
def get_created_jobs_ids(self):
|
||||
# type: () -> Mapping[str, dict]
|
||||
"""
|
||||
Return a task ids dict created ny this optimizer until now, including completed and running jobs.
|
||||
The values of the returned dict are the parameters used in the specific job
|
||||
|
||||
:return dict(str): dict of task ids (str) as keys, and their parameters dict as value
|
||||
:return dict: dict of task ids (str) as keys, and their parameters dict as value
|
||||
"""
|
||||
return self._created_jobs_ids
|
||||
|
||||
def get_top_experiments(self, top_k):
|
||||
# type: (int) -> Sequence[Task]
|
||||
"""
|
||||
Return a list of Tasks of the top performing experiments, based on the controller Objective object
|
||||
|
||||
:param int top_k: Number of Tasks (experiments) to return
|
||||
:return list: List of Task objects, ordered by performance, where index 0 is the best performing Task.
|
||||
"""
|
||||
# metric_filter =
|
||||
top_tasks = self._get_child_tasks(parent_task_id=self._job_parent_id or self._base_task_id,
|
||||
order_by=self._objective_metric.get_last_metrics_encode_field(),
|
||||
additional_filters={'page_size': int(top_k), 'page': 0})
|
||||
# noinspection PyProtectedMember
|
||||
top_tasks = self._get_child_tasks(
|
||||
parent_task_id=self._job_parent_id or self._base_task_id,
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field(),
|
||||
additional_filters={'page_size': int(top_k), 'page': 0})
|
||||
return top_tasks
|
||||
|
||||
def get_objective_metric(self):
|
||||
# type: () -> (str, str)
|
||||
"""
|
||||
Return the metric title, series pair of the objective
|
||||
|
||||
@@ -306,10 +339,19 @@ class SearchStrategy(object):
|
||||
"""
|
||||
return self._objective_metric.get_objective_metric()
|
||||
|
||||
def helper_create_job(self, base_task_id, parameter_override=None,
|
||||
task_overrides=None, tags=None, parent=None, **kwargs):
|
||||
def helper_create_job(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
parameter_override=None, # type: Optional[Mapping[str, str]]
|
||||
task_overrides=None, # type: Optional[Mapping[str, str]]
|
||||
tags=None, # type: Optional[Sequence[str]]
|
||||
parent=None, # type: Optional[str]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> TrainsJob
|
||||
"""
|
||||
Create a Job using the specified arguments, TrainsJob for details
|
||||
|
||||
:return TrainsJob: Returns a newly created Job instance
|
||||
"""
|
||||
if parameter_override:
|
||||
@@ -334,6 +376,7 @@ class SearchStrategy(object):
|
||||
return new_job
|
||||
|
||||
def set_job_class(self, job_class):
|
||||
# type: (TrainsJob) -> ()
|
||||
"""
|
||||
Set the class to use for the helper_create_job function
|
||||
|
||||
@@ -342,6 +385,7 @@ class SearchStrategy(object):
|
||||
self._job_class = job_class
|
||||
|
||||
def set_job_default_parent(self, job_parent_task_id):
|
||||
# type: (str) -> ()
|
||||
"""
|
||||
Set the default parent for all Jobs created by the helper_create_job method
|
||||
:param str job_parent_task_id: Parent task id
|
||||
@@ -349,6 +393,7 @@ class SearchStrategy(object):
|
||||
self._job_parent_id = job_parent_task_id
|
||||
|
||||
def set_job_naming_scheme(self, naming_function):
|
||||
# type: (Optional[Callable[[str, dict], str]]) -> ()
|
||||
"""
|
||||
Set the function used to name a newly created job
|
||||
|
||||
@@ -357,6 +402,7 @@ class SearchStrategy(object):
|
||||
self._naming_function = naming_function
|
||||
|
||||
def _validate_base_task(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
Check the base task exists and contains the requested objective metric and hyper parameters
|
||||
"""
|
||||
@@ -378,6 +424,7 @@ class SearchStrategy(object):
|
||||
self._objective_metric.get_objective_metric(), self._base_task_id))
|
||||
|
||||
def _get_task_project(self, parent_task_id):
|
||||
# type: (str) -> (Optional[str])
|
||||
if not parent_task_id:
|
||||
return
|
||||
if parent_task_id not in self._job_project:
|
||||
@@ -387,7 +434,14 @@ class SearchStrategy(object):
|
||||
return self._job_project.get(parent_task_id)
|
||||
|
||||
@classmethod
|
||||
def _get_child_tasks(cls, parent_task_id, status=None, order_by=None, additional_filters=None):
|
||||
def _get_child_tasks(
|
||||
cls,
|
||||
parent_task_id, # type: str
|
||||
status=None, # type: Optional[Task.TaskStatusEnum]
|
||||
order_by=None, # type: Optional[str]
|
||||
additional_filters=None # type: Optional[dict]
|
||||
):
|
||||
# type: (...) -> (Sequence[Task])
|
||||
"""
|
||||
Helper function, return a list of tasks tagged automl with specific status ordered by sort_field
|
||||
|
||||
@@ -401,7 +455,7 @@ class SearchStrategy(object):
|
||||
"execution.parameters.name"
|
||||
"updated"
|
||||
:param dict additional_filters: Additional task filters
|
||||
:return List(Task): List of Task objects
|
||||
:return list(Task): List of Task objects
|
||||
"""
|
||||
task_filter = {'parent': parent_task_id,
|
||||
# 'tags': [cls._tag],
|
||||
@@ -432,8 +486,19 @@ class GridSearch(SearchStrategy):
|
||||
Full grid sampling of every hyper-parameter combination
|
||||
"""
|
||||
|
||||
def __init__(self, base_task_id, hyper_parameters, objective_metric,
|
||||
execution_queue, num_concurrent_workers, pool_period_min=2.0, max_job_execution_minutes=None, **_):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric, # type: Objective
|
||||
execution_queue, # type: str
|
||||
num_concurrent_workers, # type: int
|
||||
pool_period_min=2., # type: float
|
||||
max_job_execution_minutes=None, # type: Optional[float]
|
||||
total_max_jobs=None, # type: Optional[int]
|
||||
**_ # type: Any
|
||||
):
|
||||
# type: (...) -> GridSearch
|
||||
"""
|
||||
Initialize a grid search optimizer
|
||||
|
||||
@@ -444,14 +509,17 @@ class GridSearch(SearchStrategy):
|
||||
:param int num_concurrent_workers: Limit number of concurrent running machines
|
||||
:param float pool_period_min: time in minutes between two consecutive polls
|
||||
:param float max_job_execution_minutes: maximum time per single job in minutes, if exceeded job is aborted
|
||||
:param int total_max_jobs: total maximum job for the optimization process. Default None, unlimited
|
||||
"""
|
||||
super(GridSearch, self).__init__(
|
||||
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
|
||||
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
|
||||
pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes)
|
||||
pool_period_min=pool_period_min, max_job_execution_minutes=max_job_execution_minutes,
|
||||
total_max_jobs=total_max_jobs, **_)
|
||||
self._param_iterator = None
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
"""
|
||||
Create a new job if needed. return the newly created job.
|
||||
If no job needs to be created, return None
|
||||
@@ -466,6 +534,7 @@ class GridSearch(SearchStrategy):
|
||||
return self.helper_create_job(base_task_id=self._base_task_id, parameter_override=parameters)
|
||||
|
||||
def monitor_job(self, job):
|
||||
# type: (TrainsJob) -> bool
|
||||
"""
|
||||
Check if the job needs to be aborted or already completed
|
||||
if return False, the job was aborted / completed, and should be taken off the current job list
|
||||
@@ -480,6 +549,7 @@ class GridSearch(SearchStrategy):
|
||||
return not job.is_stopped()
|
||||
|
||||
def _next_configuration(self):
|
||||
# type: () -> Mapping[str, str]
|
||||
def param_iterator_fn():
|
||||
hyper_params_values = [p.to_list() for p in self._hyper_parameters]
|
||||
for state in product(*hyper_params_values):
|
||||
@@ -499,9 +569,19 @@ class RandomSearch(SearchStrategy):
|
||||
# Number of already chosen random samples before assuming we covered the entire hyper-parameter space
|
||||
_hp_space_cover_samples = 42
|
||||
|
||||
def __init__(self, base_task_id, hyper_parameters, objective_metric,
|
||||
execution_queue, num_concurrent_workers, pool_period_min=2.0,
|
||||
max_job_execution_minutes=None, total_max_jobs=None, **_):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric, # type: Objective
|
||||
execution_queue, # type: str
|
||||
num_concurrent_workers, # type: int
|
||||
pool_period_min=2., # type: float
|
||||
max_job_execution_minutes=None, # type: Optional[float]
|
||||
total_max_jobs=None, # type: Optional[int]
|
||||
**_ # type: Any
|
||||
):
|
||||
# type: (...) -> RandomSearch
|
||||
"""
|
||||
Initialize a random search optimizer
|
||||
|
||||
@@ -518,10 +598,11 @@ class RandomSearch(SearchStrategy):
|
||||
base_task_id=base_task_id, hyper_parameters=hyper_parameters, objective_metric=objective_metric,
|
||||
execution_queue=execution_queue, num_concurrent_workers=num_concurrent_workers,
|
||||
pool_period_min=pool_period_min,
|
||||
max_job_execution_minutes=max_job_execution_minutes, total_max_jobs=total_max_jobs)
|
||||
max_job_execution_minutes=max_job_execution_minutes, total_max_jobs=total_max_jobs, **_)
|
||||
self._hyper_parameters_collection = set()
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
"""
|
||||
Create a new job if needed. return the newly created job.
|
||||
If no job needs to be created, return None
|
||||
@@ -551,6 +632,7 @@ class RandomSearch(SearchStrategy):
|
||||
return self.helper_create_job(base_task_id=self._base_task_id, parameter_override=parameters)
|
||||
|
||||
def monitor_job(self, job):
|
||||
# type: (TrainsJob) -> bool
|
||||
"""
|
||||
Check if the job needs to be aborted or already completed
|
||||
if return False, the job was aborted / completed, and should be taken off the current job list
|
||||
@@ -572,18 +654,22 @@ class HyperParameterOptimizer(object):
|
||||
"""
|
||||
_tag = 'optimization'
|
||||
|
||||
def __init__(self, base_task_id,
|
||||
hyper_parameters,
|
||||
objective_metric_title,
|
||||
objective_metric_series,
|
||||
objective_metric_sign='min',
|
||||
optimizer_class=RandomSearch,
|
||||
max_number_of_concurrent_tasks=10,
|
||||
execution_queue='default',
|
||||
optimization_time_limit=None,
|
||||
auto_connect_task=True,
|
||||
always_create_task=False,
|
||||
**optimizer_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric_title, # type: str
|
||||
objective_metric_series, # type: str
|
||||
objective_metric_sign='min', # type: Union['min', 'max', 'min_global', 'max_global']
|
||||
optimizer_class=RandomSearch, # type: SearchStrategy
|
||||
max_number_of_concurrent_tasks=10, # type: int
|
||||
execution_queue='default', # type: str
|
||||
optimization_time_limit=None, # type: Optional[float]
|
||||
auto_connect_task=True, # type: bool
|
||||
always_create_task=False, # type: bool
|
||||
**optimizer_kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> HyperParameterOptimizer
|
||||
"""
|
||||
Create a new hyper-parameter controller. The newly created object will launch and monitor the new experiments.
|
||||
|
||||
@@ -702,6 +788,7 @@ class HyperParameterOptimizer(object):
|
||||
self.set_time_limit(in_minutes=opts['optimization_time_limit'])
|
||||
|
||||
def get_num_active_experiments(self):
|
||||
# type: () -> int
|
||||
"""
|
||||
Return the number of current active experiments
|
||||
|
||||
@@ -712,6 +799,7 @@ class HyperParameterOptimizer(object):
|
||||
return len(self.optimizer.get_running_jobs())
|
||||
|
||||
def get_active_experiments(self):
|
||||
# type: () -> Sequence[Task]
|
||||
"""
|
||||
Return a list of Tasks of the current active experiments
|
||||
|
||||
@@ -722,6 +810,7 @@ class HyperParameterOptimizer(object):
|
||||
return [j.task for j in self.optimizer.get_running_jobs()]
|
||||
|
||||
def start(self, job_complete_callback=None):
|
||||
# type: (Optional[Callable[[str, float, int, dict, str], None]]) -> bool
|
||||
"""
|
||||
Start the HyperParameterOptimizer controller.
|
||||
If the calling process is stopped, the controller will stop as well.
|
||||
@@ -755,6 +844,7 @@ class HyperParameterOptimizer(object):
|
||||
return True
|
||||
|
||||
def stop(self, timeout=None):
|
||||
# type: (Optional[float]) -> ()
|
||||
"""
|
||||
Stop the HyperParameterOptimizer controller and optimization thread,
|
||||
|
||||
@@ -770,7 +860,7 @@ class HyperParameterOptimizer(object):
|
||||
|
||||
# wait for optimizer thread
|
||||
if timeout is not None:
|
||||
_thread.join(timeout=timeout*60.)
|
||||
_thread.join(timeout=timeout * 60.)
|
||||
|
||||
# stop all running tasks:
|
||||
for j in self.optimizer.get_running_jobs():
|
||||
@@ -782,6 +872,7 @@ class HyperParameterOptimizer(object):
|
||||
self._thread_reporter.join()
|
||||
|
||||
def is_active(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if the optimization procedure is still running
|
||||
Note, if the daemon thread has not yet started, it will still return True
|
||||
@@ -791,6 +882,7 @@ class HyperParameterOptimizer(object):
|
||||
return self._stop_event is None or self._thread is not None
|
||||
|
||||
def is_running(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if the optimization controller is running
|
||||
|
||||
@@ -799,6 +891,7 @@ class HyperParameterOptimizer(object):
|
||||
return self._thread is not None
|
||||
|
||||
def wait(self, timeout=None):
|
||||
# type: (Optional[float]) -> bool
|
||||
"""
|
||||
Wait for the optimizer to finish.
|
||||
It will not stop the optimizer in any case. Call stop() to terminate the optimizer.
|
||||
@@ -825,6 +918,7 @@ class HyperParameterOptimizer(object):
|
||||
return True
|
||||
|
||||
def set_time_limit(self, in_minutes=None, specific_time=None):
|
||||
# type: (Optional[float], Optional[datetime]) -> ()
|
||||
"""
|
||||
Set a time limit for the HyperParameterOptimizer controller,
|
||||
i.e. if we reached the time limit, stop the optimization process
|
||||
@@ -835,9 +929,10 @@ class HyperParameterOptimizer(object):
|
||||
if specific_time:
|
||||
self.optimization_timeout = specific_time.timestamp()
|
||||
else:
|
||||
self.optimization_timeout = (in_minutes*60.) + time() if in_minutes else None
|
||||
self.optimization_timeout = (in_minutes * 60.) + time() if in_minutes else None
|
||||
|
||||
def get_time_limit(self):
|
||||
# type: () -> datetime
|
||||
"""
|
||||
Return the controller optimization time limit.
|
||||
|
||||
@@ -846,6 +941,7 @@ class HyperParameterOptimizer(object):
|
||||
return datetime.fromtimestamp(self.optimization_timeout)
|
||||
|
||||
def elapsed(self):
|
||||
# type: () -> float
|
||||
"""
|
||||
Return minutes elapsed from controller starting time stamp
|
||||
|
||||
@@ -853,9 +949,10 @@ class HyperParameterOptimizer(object):
|
||||
"""
|
||||
if self.optimization_start_time is None:
|
||||
return -1.0
|
||||
return (time() - self.optimization_start_time)/60.
|
||||
return (time() - self.optimization_start_time) / 60.
|
||||
|
||||
def reached_time_limit(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return True if we passed the time limit. Function returns immediately, it does not wait for the optimizer.
|
||||
|
||||
@@ -869,6 +966,7 @@ class HyperParameterOptimizer(object):
|
||||
return time() > self.optimization_timeout
|
||||
|
||||
def get_top_experiments(self, top_k):
|
||||
# type: (int) -> Sequence[Task]
|
||||
"""
|
||||
Return a list of Tasks of the top performing experiments, based on the controller Objective object
|
||||
|
||||
@@ -880,9 +978,16 @@ class HyperParameterOptimizer(object):
|
||||
return self.optimizer.get_top_experiments(top_k=top_k)
|
||||
|
||||
def get_optimizer(self):
    # type: () -> SearchStrategy
    """
    Expose the optimizer object this controller is currently using.

    :return SearchStrategy: the object stored in self.optimizer
    """
    return self.optimizer
|
||||
|
||||
def set_default_job_class(self, job_class):
|
||||
# type: (TrainsJob) -> ()
|
||||
"""
|
||||
Set the Job class to use when the optimizer spawns new Jobs
|
||||
|
||||
@@ -891,6 +996,7 @@ class HyperParameterOptimizer(object):
|
||||
self.optimizer.set_job_class(job_class)
|
||||
|
||||
def set_report_period(self, report_period_minutes):
|
||||
# type: (float) -> ()
|
||||
"""
|
||||
Set reporting period in minutes, for the accumulated objective report
|
||||
This report is sent on the Optimizer Task, and collects objective metric from all running jobs.
|
||||
@@ -900,6 +1006,7 @@ class HyperParameterOptimizer(object):
|
||||
self._report_period_min = float(report_period_minutes)
|
||||
|
||||
def _connect_args(self, optimizer_class=None, hyper_param_configuration=None, **kwargs):
|
||||
# type: (SearchStrategy, dict, Any) -> (SearchStrategy, list, dict)
|
||||
if not self._task:
|
||||
logger.warning('Auto Connect turned on but no Task was found, '
|
||||
'hyper-parameter optimization argument logging disabled')
|
||||
@@ -937,6 +1044,7 @@ class HyperParameterOptimizer(object):
|
||||
return optimizer_class, configuration_dict['parameter_optimization_space'], arguments['opt']
|
||||
|
||||
def _daemon(self):
|
||||
# type: () -> ()
|
||||
"""
|
||||
implement the main polling thread, calling loop every self.pool_period_minutes minutes
|
||||
"""
|
||||
@@ -944,6 +1052,7 @@ class HyperParameterOptimizer(object):
|
||||
self._thread = None
|
||||
|
||||
def _report_daemon(self):
|
||||
# type: () -> ()
|
||||
worker_to_series = {}
|
||||
title, series = self.objective_metric.get_objective_metric()
|
||||
title = '{}/{}'.format(title, series)
|
||||
@@ -956,8 +1065,8 @@ class HyperParameterOptimizer(object):
|
||||
timeout = self.optimization_timeout - time() if self.optimization_timeout else 0.
|
||||
|
||||
if timeout >= 0:
|
||||
timeout = min(self._report_period_min*60., timeout if timeout else self._report_period_min*60.)
|
||||
print('Progress report #{} completed, sleeping for {} minutes'.format(counter, timeout/60.))
|
||||
timeout = min(self._report_period_min * 60., timeout if timeout else self._report_period_min * 60.)
|
||||
print('Progress report #{} completed, sleeping for {} minutes'.format(counter, timeout / 60.))
|
||||
if self._stop_event.wait(timeout=timeout):
|
||||
# wait for one last report
|
||||
timeout = -1
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import sys
|
||||
from itertools import product
|
||||
from random import Random as BaseRandom
|
||||
|
||||
import numpy as np
|
||||
from typing import Mapping, Any, Sequence, Optional, Union
|
||||
|
||||
|
||||
class RandomSeed(object):
|
||||
@@ -14,6 +13,7 @@ class RandomSeed(object):
|
||||
|
||||
@staticmethod
|
||||
def set_random_seed(seed=1337):
|
||||
# type: (int) -> ()
|
||||
"""
|
||||
Set global seed for all hyper parameter strategy random number sampling
|
||||
|
||||
@@ -24,6 +24,7 @@ class RandomSeed(object):
|
||||
|
||||
@staticmethod
|
||||
def get_random_seed():
|
||||
# type: () -> int
|
||||
"""
|
||||
Get the global seed for all hyper parameter strategy random number sampling
|
||||
|
||||
@@ -38,9 +39,16 @@ class Parameter(RandomSeed):
|
||||
"""
|
||||
|
||||
def __init__(self, name):
    # type: (Optional[str]) -> None
    """
    Create a new Parameter for hyper-parameter optimization

    :param str name: give a name to the parameter, this is the parameter name that will be passed to a Task
        (the type comment allows None as well)
    """
    # the name is stored as-is, no validation is done here
    self.name = name
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""
|
||||
Return a dict with the Parameter name and a sampled value for the Parameter
|
||||
|
||||
@@ -49,16 +57,20 @@ class Parameter(RandomSeed):
|
||||
pass
|
||||
|
||||
def to_list(self):
|
||||
# type: () -> Sequence[Mapping[str, Any]]
|
||||
"""
|
||||
Return a list of all the valid values of the Parameter
|
||||
:return list: list of dicts {name: values}
|
||||
|
||||
:return list: list of dicts {name: value}
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_dict(self):
|
||||
# type: () -> Mapping[str, Union[str, Parameter]]
|
||||
"""
|
||||
Return a dict representation of the parameter object
|
||||
:return: dict
|
||||
Return a dict representation of the parameter object. Used for serialization of the Parameter object.
|
||||
|
||||
:return dict: dict representation of the object (serialization)
|
||||
"""
|
||||
serialize = {'__class__': str(self.__class__).split('.')[-1][:-2]}
|
||||
serialize.update(dict(((k, v.to_dict() if hasattr(v, 'to_dict') else v) for k, v in self.__dict__.items())))
|
||||
@@ -66,9 +78,11 @@ class Parameter(RandomSeed):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, a_dict):
|
||||
# type: (Mapping[str, str]) -> Parameter
|
||||
"""
|
||||
Construct parameter object from a dict representation
|
||||
:return: Parameter
|
||||
Construct parameter object from a dict representation (deserialize from dict)
|
||||
|
||||
:return Parameter: Parameter object
|
||||
"""
|
||||
a_dict = a_dict.copy()
|
||||
a_cls = a_dict.pop('__class__', None)
|
||||
@@ -89,7 +103,15 @@ class UniformParameterRange(Parameter):
|
||||
Uniform randomly sampled Hyper-Parameter object
|
||||
"""
|
||||
|
||||
def __init__(self, name, min_value, max_value, step_size=None, include_max_value=True):
|
||||
def __init__(
|
||||
self,
|
||||
name, # type: str
|
||||
min_value, # type: float
|
||||
max_value, # type: float
|
||||
step_size=None, # type: Optional[float]
|
||||
include_max_value=True # type: bool
|
||||
):
|
||||
# type: (...) -> UniformParameterRange
|
||||
"""
|
||||
Create a parameter to be sampled by the SearchStrategy
|
||||
|
||||
@@ -106,6 +128,7 @@ class UniformParameterRange(Parameter):
|
||||
self.include_max = include_max_value
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""
|
||||
Return uniformly sampled value based on object sampling definitions
|
||||
|
||||
@@ -117,14 +140,15 @@ class UniformParameterRange(Parameter):
|
||||
return {self.name: self.min_value + (self._random.randrange(start=0, stop=round(steps)) * self.step_size)}
|
||||
|
||||
def to_list(self):
    # type: () -> Sequence[Mapping[str, float]]
    """
    Return a list of all the valid values of the Parameter.
    If self.step_size is not defined, the [min_value, max_value] range is split into 100 equal steps.

    :return list: list of dicts {name: float}, starting at min_value and advancing by the step size.
        max_value is appended when include_max is set and the grid did not already reach it.
    """
    # fall back to 1/100 of the range when no explicit step size was given
    step_size = self.step_size or (self.max_value - self.min_value) / 100.
    # divide by the resolved step size (self.step_size may be None);
    # a zero-width range yields a zero step, so guard against division by zero
    steps = (self.max_value - self.min_value) / step_size if step_size else 0
    # grid points are offsets from min_value, not from zero
    values = [self.min_value + v * step_size for v in range(0, int(steps))]
    if self.include_max and (not values or values[-1] < self.max_value):
        values.append(self.max_value)
    return [{self.name: v} for v in values]
|
||||
@@ -152,6 +176,7 @@ class UniformIntegerParameterRange(Parameter):
|
||||
self.include_max = include_max_value
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""
|
||||
Return uniformly sampled value based on object sampling definitions
|
||||
|
||||
@@ -162,10 +187,12 @@ class UniformIntegerParameterRange(Parameter):
|
||||
stop=self.max_value + (0 if not self.include_max else self.step_size))}
|
||||
|
||||
def to_list(self):
|
||||
# type: () -> Sequence[Mapping[str, int]]
|
||||
"""
|
||||
Return a list of all the valid values of the Parameter
|
||||
if self.step_size is not defined, return 100 points between minmax values
|
||||
:return list: list of dicts {name: float}
|
||||
|
||||
:return list: list of dicts {name: int}
|
||||
"""
|
||||
values = list(range(self.min_value, self.max_value, self.step_size))
|
||||
if self.include_max and (not values or values[-1] < self.max_value):
|
||||
@@ -189,6 +216,7 @@ class DiscreteParameterRange(Parameter):
|
||||
self.values = values
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""
|
||||
Return uniformly sampled value from the valid list of values
|
||||
|
||||
@@ -197,8 +225,10 @@ class DiscreteParameterRange(Parameter):
|
||||
return {self.name: self._random.choice(self.values)}
|
||||
|
||||
def to_list(self):
    # type: () -> Sequence[Mapping[str, Any]]
    """
    Enumerate every valid value of the Parameter.

    :return list: list of dicts {name: value}, one per entry of self.values, in the same order
    """
    param_name = self.name
    options = []
    for candidate in self.values:
        options.append({param_name: candidate})
    return options
|
||||
@@ -210,6 +240,7 @@ class ParameterSet(Parameter):
|
||||
"""
|
||||
|
||||
def __init__(self, parameter_combinations=()):
|
||||
# type: (Sequence[Mapping[str, Union[float, int, str, Parameter]]]) -> ParameterSet
|
||||
"""
|
||||
Uniformly sample values from a list of discrete options (combinations) of parameters
|
||||
|
||||
@@ -225,6 +256,7 @@ class ParameterSet(Parameter):
|
||||
self.values = parameter_combinations
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> Mapping[str, Any]
|
||||
"""
|
||||
Return uniformly sampled value from the valid list of values
|
||||
|
||||
@@ -233,8 +265,10 @@ class ParameterSet(Parameter):
|
||||
return self._get_value(self._random.choice(self.values))
|
||||
|
||||
def to_list(self):
|
||||
# type: () -> Sequence[Mapping[str, Any]]
|
||||
"""
|
||||
Return a list of all the valid values of the Parameter
|
||||
|
||||
:return list: list of dicts {name: value}
|
||||
"""
|
||||
combinations = []
|
||||
@@ -253,6 +287,7 @@ class ParameterSet(Parameter):
|
||||
|
||||
@staticmethod
|
||||
def _get_value(combination):
|
||||
# type: (dict) -> dict
|
||||
value_dict = {}
|
||||
for k, v in combination.items():
|
||||
if isinstance(v, Parameter):
|
||||
|
||||
@@ -23,12 +23,12 @@ class ApiServiceProxy(object):
|
||||
if not ApiServiceProxy._max_available_version:
|
||||
from ..backend_api import services
|
||||
ApiServiceProxy._max_available_version = max([
|
||||
Version(name[1:].replace("_", "."))
|
||||
for name in [
|
||||
module_name
|
||||
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
|
||||
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
|
||||
]])
|
||||
Version(name[1:].replace("_", "."))
|
||||
for name in [
|
||||
module_name
|
||||
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
|
||||
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
|
||||
]])
|
||||
|
||||
version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
|
||||
self.__dict__["__wrapped_version__"] = Session.api_version
|
||||
|
||||
@@ -59,7 +59,7 @@ class DataModel(object):
|
||||
props = {}
|
||||
for c in cls.__mro__:
|
||||
props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
|
||||
if isinstance(v, property)})
|
||||
if isinstance(v, property)})
|
||||
cls._data_props_list = props
|
||||
return props.copy()
|
||||
|
||||
@@ -150,6 +150,7 @@ class NonStrictDataModelMixin(object):
|
||||
|
||||
:summary: supplies an __init__ method that warns about unused keywords
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# unexpected = [key for key in kwargs if not key.startswith('_')]
|
||||
# if unexpected:
|
||||
|
||||
@@ -312,8 +312,7 @@ class Config(object):
|
||||
return ConfigFactory.parse_file(file_path)
|
||||
except ParseSyntaxException as ex:
|
||||
msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
|
||||
file_path, ex
|
||||
)
|
||||
file_path, ex)
|
||||
six.reraise(
|
||||
ConfigurationError,
|
||||
ConfigurationError(msg, file_path=file_path),
|
||||
|
||||
@@ -273,7 +273,7 @@ class UploadEvent(MetricsEventAdapter):
|
||||
image_data = np.atleast_3d(image_data)
|
||||
if image_data.dtype != np.uint8:
|
||||
if np.issubdtype(image_data.dtype, np.floating) and image_data.max() <= 1.0:
|
||||
image_data = (image_data*255).astype(np.uint8)
|
||||
image_data = (image_data * 255).astype(np.uint8)
|
||||
else:
|
||||
image_data = image_data.astype(np.uint8)
|
||||
shape = image_data.shape
|
||||
@@ -318,7 +318,7 @@ class UploadEvent(MetricsEventAdapter):
|
||||
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
|
||||
# make sure we preserve local path root
|
||||
if e_storage_uri.startswith('/'):
|
||||
url = '/'+url
|
||||
url = '/' + url
|
||||
|
||||
if quote_uri:
|
||||
url = quote_url(url)
|
||||
|
||||
@@ -155,7 +155,7 @@ class Metrics(InterfaceBase):
|
||||
# upload files
|
||||
def upload(e):
|
||||
upload_uri = e.upload_uri or storage_uri
|
||||
|
||||
|
||||
try:
|
||||
storage = self._get_storage(upload_uri)
|
||||
url = storage.upload_from_stream(e.stream, e.url, retries=self._file_upload_retries)
|
||||
@@ -234,4 +234,3 @@ class Metrics(InterfaceBase):
|
||||
pool.join()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
))
|
||||
self.reload()
|
||||
|
||||
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
uri=None, framework=None, iteration=None):
|
||||
if tags:
|
||||
extra = {'system_tags': tags or self.data.system_tags} \
|
||||
@@ -318,7 +318,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
callback = partial(
|
||||
self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
|
||||
override_model_id=override_model_id, cb=cb)
|
||||
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable, cb=callback)
|
||||
uri = self._upload_model(model_file, target_filename=target_filename,
|
||||
async_enable=async_enable, cb=callback)
|
||||
return uri
|
||||
else:
|
||||
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
|
||||
|
||||
@@ -16,6 +16,7 @@ class _Arguments(object):
|
||||
|
||||
class _ProxyDictWrite(dict):
|
||||
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
||||
|
||||
def __init__(self, arguments, *args, **kwargs):
|
||||
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
|
||||
self._arguments = arguments
|
||||
@@ -27,6 +28,7 @@ class _Arguments(object):
|
||||
|
||||
class _ProxyDictReadOnly(dict):
|
||||
""" Dictionary wrapper that prevents modifications to the dictionary """
|
||||
|
||||
def __init__(self, arguments, *args, **kwargs):
|
||||
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
|
||||
self._arguments = arguments
|
||||
|
||||
@@ -16,7 +16,7 @@ buffer_capacity = config.get('log.task_log_buffer_capacity', 100)
|
||||
class TaskHandler(BufferingHandler):
|
||||
__flush_max_history_seconds = 30.
|
||||
__wait_for_flush_timeout = 10.
|
||||
__max_event_size = 1024*1024
|
||||
__max_event_size = 1024 * 1024
|
||||
__once = False
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ class ScriptInfoError(Exception):
|
||||
|
||||
|
||||
class ScriptRequirements(object):
|
||||
_max_requirements_size = 512*1024
|
||||
_max_requirements_size = 512 * 1024
|
||||
|
||||
def __init__(self, root_folder):
|
||||
self._root_folder = root_folder
|
||||
@@ -365,7 +365,7 @@ class ScriptInfo(object):
|
||||
|
||||
@classmethod
|
||||
def _get_jupyter_notebook_filename(cls):
|
||||
if not (sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or
|
||||
if not (sys.argv[0].endswith(os.path.sep + 'ipykernel_launcher.py') or
|
||||
sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
|
||||
or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
|
||||
return None
|
||||
|
||||
@@ -119,7 +119,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._curr_label_stats = {}
|
||||
self._raise_on_validation_errors = raise_on_validation_errors
|
||||
self._parameters_allowed_types = (
|
||||
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
|
||||
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
|
||||
)
|
||||
self._app_server = None
|
||||
self._files_server = None
|
||||
@@ -683,7 +683,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
"Only builtin types ({}) are allowed for values (got {})".format(
|
||||
', '.join(t.__name__ for t in self._parameters_allowed_types),
|
||||
', '.join('%s=>%s' % p for p in not_allowed.items())),
|
||||
)
|
||||
)
|
||||
|
||||
# force cast all variables to strings (so that we can later edit them in UI)
|
||||
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
|
||||
@@ -1300,5 +1300,5 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
def __is_subprocess(cls):
|
||||
# notice this class function is called from Task.ExitHooks, do not rename/move it.
|
||||
is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
|
||||
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
|
||||
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
|
||||
return is_subprocess
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -199,6 +199,7 @@ class Artifacts(object):
|
||||
|
||||
class _ProxyDictWrite(dict):
|
||||
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
||||
|
||||
def __init__(self, artifacts_manager, *args, **kwargs):
|
||||
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
|
||||
self._artifacts_manager = artifacts_manager
|
||||
@@ -303,8 +304,8 @@ class Artifacts(object):
|
||||
artifact_type_data.content_type = 'application/numpy'
|
||||
artifact_type_data.preview = str(artifact_object.__repr__())
|
||||
override_filename_ext_in_uri = '.npz'
|
||||
override_filename_in_uri = name+override_filename_ext_in_uri
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
|
||||
override_filename_in_uri = name + override_filename_ext_in_uri
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
|
||||
os.close(fd)
|
||||
np.savez_compressed(local_filename, **{name: artifact_object})
|
||||
delete_after_upload = True
|
||||
@@ -314,7 +315,7 @@ class Artifacts(object):
|
||||
artifact_type_data.preview = str(artifact_object.__repr__())
|
||||
override_filename_ext_in_uri = self._save_format
|
||||
override_filename_in_uri = name
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
|
||||
os.close(fd)
|
||||
artifact_object.to_csv(local_filename, compression=self._compression)
|
||||
delete_after_upload = True
|
||||
@@ -325,7 +326,7 @@ class Artifacts(object):
|
||||
artifact_type_data.preview = desc[1:desc.find(' at ')]
|
||||
override_filename_ext_in_uri = '.png'
|
||||
override_filename_in_uri = name + override_filename_ext_in_uri
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
|
||||
os.close(fd)
|
||||
artifact_object.save(local_filename)
|
||||
delete_after_upload = True
|
||||
@@ -335,7 +336,7 @@ class Artifacts(object):
|
||||
preview = json.dumps(artifact_object, sort_keys=True, indent=4)
|
||||
override_filename_ext_in_uri = '.json'
|
||||
override_filename_in_uri = name + override_filename_ext_in_uri
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="")+'.', suffix=override_filename_ext_in_uri)
|
||||
fd, local_filename = mkstemp(prefix=quote(name, safe="") + '.', suffix=override_filename_ext_in_uri)
|
||||
os.write(fd, bytes(preview.encode()))
|
||||
os.close(fd)
|
||||
artifact_type_data.preview = preview
|
||||
@@ -374,7 +375,7 @@ class Artifacts(object):
|
||||
override_filename_ext_in_uri = '.zip'
|
||||
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
|
||||
fd, zip_file = mkstemp(
|
||||
prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri
|
||||
prefix=quote(folder.parts[-1], safe="") + '.', suffix=override_filename_ext_in_uri
|
||||
)
|
||||
try:
|
||||
artifact_type_data.content_type = 'application/zip'
|
||||
@@ -571,7 +572,8 @@ class Artifacts(object):
|
||||
|
||||
artifact_type_data.data_hash = current_sha2
|
||||
artifact_type_data.content_type = "text/csv"
|
||||
artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})
|
||||
artifact_type_data.preview = str(pd_artifact.__repr__(
|
||||
)) + '\n\n' + self._get_statistics({name: pd_artifact})
|
||||
|
||||
artifact.type_data = artifact_type_data
|
||||
artifact.uri = uri
|
||||
@@ -657,11 +659,11 @@ class Artifacts(object):
|
||||
# build intersection summary
|
||||
for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
|
||||
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
|
||||
name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
|
||||
for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
|
||||
name=name, shape=shape, unique=len(unique_hash), percentage=100 * len(unique_hash) / float(shape[0]))
|
||||
for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
|
||||
intersection = len(unique_hash & unique_hash2)
|
||||
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
|
||||
name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
|
||||
name2=name2, intersection=intersection, percentage=100 * intersection / float(len(unique_hash2)))
|
||||
except Exception as e:
|
||||
LoggerRoot.get_base_logger().warning(str(e))
|
||||
finally:
|
||||
|
||||
@@ -230,8 +230,8 @@ class EventTrainsWriter(object):
|
||||
:return: (str, str) variant and metric
|
||||
"""
|
||||
splitted_tag = tag.split(split_char)
|
||||
if auto_reduce_num_split and num_split_parts > len(splitted_tag)-1:
|
||||
num_split_parts = max(1, len(splitted_tag)-1)
|
||||
if auto_reduce_num_split and num_split_parts > len(splitted_tag) - 1:
|
||||
num_split_parts = max(1, len(splitted_tag) - 1)
|
||||
series = join_char.join(splitted_tag[-num_split_parts:])
|
||||
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
|
||||
|
||||
@@ -356,7 +356,7 @@ class EventTrainsWriter(object):
|
||||
val = val[:, :, [0, 1, 2]]
|
||||
except Exception:
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).warning('Failed decoding debug image [%d, %d, %d]'
|
||||
% (width, height, color_channels))
|
||||
% (width, height, color_channels))
|
||||
val = None
|
||||
return val
|
||||
|
||||
@@ -525,7 +525,7 @@ class EventTrainsWriter(object):
|
||||
stream = BytesIO(audio_data)
|
||||
if values:
|
||||
file_extension = guess_extension(values['contentType']) or \
|
||||
'.{}'.format(values['contentType'].split('/')[-1])
|
||||
'.{}'.format(values['contentType'].split('/')[-1])
|
||||
else:
|
||||
# assume wav as default
|
||||
file_extension = '.wav'
|
||||
@@ -548,7 +548,7 @@ class EventTrainsWriter(object):
|
||||
wraparound_counter = EventTrainsWriter._title_series_wraparound_counter[key]
|
||||
# we decide on wrap around if the current step is less than 10% of the previous step
|
||||
# note: since the counter is an int and we want to avoid rounding errors, the condition below performs a double check
|
||||
if step < wraparound_counter['last_step'] and step < 0.9*wraparound_counter['last_step']:
|
||||
if step < wraparound_counter['last_step'] and step < 0.9 * wraparound_counter['last_step']:
|
||||
# adjust step base line
|
||||
wraparound_counter['adjust_counter'] += wraparound_counter['last_step'] + (1 if step <= 0 else step)
|
||||
|
||||
@@ -582,7 +582,8 @@ class EventTrainsWriter(object):
|
||||
msg_dict.pop('wallTime', None)
|
||||
keys_list = [key for key in msg_dict.keys() if len(key) > 0]
|
||||
keys_list = ', '.join(keys_list)
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('event summary not found, message type unsupported: %s' % keys_list)
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||
'event summary not found, message type unsupported: %s' % keys_list)
|
||||
return
|
||||
value_dicts = summary.get('value')
|
||||
walltime = walltime or msg_dict.get('step')
|
||||
@@ -594,7 +595,8 @@ class EventTrainsWriter(object):
|
||||
step = int(event.step)
|
||||
else:
|
||||
step = 0
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug('Received event without step, assuming step = {}'.format(step))
|
||||
LoggerRoot.get_base_logger(TensorflowBinding).debug(
|
||||
'Received event without step, assuming step = {}'.format(step))
|
||||
else:
|
||||
step = int(step)
|
||||
self._max_step = max(self._max_step, step)
|
||||
@@ -1036,7 +1038,7 @@ class PatchModelCheckPointCallback(object):
|
||||
name=defaults_dict.get('name'),
|
||||
comment=defaults_dict.get('comment'),
|
||||
label_enumeration=defaults_dict.get('label_enumeration') or
|
||||
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
|
||||
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
|
||||
framework=Framework.keras,
|
||||
)
|
||||
output_model.set_upload_destination(
|
||||
@@ -1206,7 +1208,7 @@ class PatchTensorFlowEager(object):
|
||||
for i in range(2, img_data_np.size):
|
||||
img_data = {'width': -1, 'height': -1,
|
||||
'colorspace': 'RGB', 'encodedImageString': img_data_np[i]}
|
||||
image_tag = str(tag)+'/sample_{}'.format(i-2) if img_data_np.size > 3 else str(tag)
|
||||
image_tag = str(tag) + '/sample_{}'.format(i - 2) if img_data_np.size > 3 else str(tag)
|
||||
event_writer._add_image(tag=image_tag,
|
||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||
img_data=img_data)
|
||||
@@ -1291,7 +1293,7 @@ class PatchKerasModelIO(object):
|
||||
PatchKerasModelIO._updated_config)
|
||||
if hasattr(Sequential.from_config, '__func__'):
|
||||
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
|
||||
PatchKerasModelIO._from_config))
|
||||
PatchKerasModelIO._from_config))
|
||||
else:
|
||||
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ class PatchedMatplotlib:
|
||||
# check if this is an imshow
|
||||
if hasattr(stored_figure, '_trains_is_imshow'):
|
||||
# flag will be cleared when calling clf() (object will be replaced)
|
||||
stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow-1)
|
||||
stored_figure._trains_is_imshow = max(0, stored_figure._trains_is_imshow - 1)
|
||||
force_save_as_image = True
|
||||
# get current figure
|
||||
mpl_fig = stored_figure.canvas.figure # plt.gcf()
|
||||
@@ -342,7 +342,7 @@ class PatchedMatplotlib:
|
||||
a_plt.savefig(buffer_, dpi=fig_dpi, format=image_format, bbox_inches='tight', pad_inches=0,
|
||||
facecolor=None)
|
||||
buffer_.seek(0)
|
||||
fd, image = mkstemp(suffix='.'+image_format)
|
||||
fd, image = mkstemp(suffix='.' + image_format)
|
||||
os.write(fd, buffer_.read())
|
||||
os.close(fd)
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -127,7 +127,8 @@ def main():
|
||||
# this is our demo server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(
|
||||
'demoapp.', 'demofiles.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('app.'):
|
||||
# this is our application server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
|
||||
@@ -138,7 +139,8 @@ def main():
|
||||
parsed_host.netloc))
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(
|
||||
'demoapi.', 'demofiles.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('api.'):
|
||||
print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
|
||||
parsed_host.netloc))
|
||||
|
||||
@@ -42,8 +42,9 @@ def _log_stderr(name, fnc, args, kwargs, is_return):
|
||||
# get a nicer thread id
|
||||
h = int(__thread_id())
|
||||
ts = time.time() - __trace_start
|
||||
__stream_write('{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n'.format('-' if is_return else '',
|
||||
ts, os.getpid(), h, threading.current_thread().name, t))
|
||||
__stream_write('{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n'.format(
|
||||
'-' if is_return else '', ts, os.getpid(),
|
||||
h, threading.current_thread().name, t))
|
||||
if __stream_flush:
|
||||
__stream_flush()
|
||||
except:
|
||||
@@ -167,7 +168,7 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None):
|
||||
# check if this is even in our module
|
||||
if hasattr(fnc, '__module__') and fnc.__module__ != module.__module__:
|
||||
pass # print('not ours {} {}'.format(module, fnc))
|
||||
elif hasattr(fnc, '__qualname__') and fnc.__qualname__.startswith(module.__name__+'.'):
|
||||
elif hasattr(fnc, '__qualname__') and fnc.__qualname__.startswith(module.__name__ + '.'):
|
||||
if isinstance(module.__dict__[fn], classmethod):
|
||||
setattr(module, fn, _traced_call_cls(prefix + fn, fnc))
|
||||
elif isinstance(module.__dict__[fn], staticmethod):
|
||||
@@ -176,7 +177,7 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None):
|
||||
setattr(module, fn, _traced_call_method(prefix + fn, fnc))
|
||||
else:
|
||||
# probably not ours hopefully static function
|
||||
if hasattr(fnc, '__qualname__') and not fnc.__qualname__.startswith(module.__name__+'.'):
|
||||
if hasattr(fnc, '__qualname__') and not fnc.__qualname__.startswith(module.__name__ + '.'):
|
||||
pass # print('not ours {} {}'.format(module, fnc))
|
||||
else:
|
||||
# we should not get here
|
||||
@@ -232,7 +233,7 @@ def trace_trains(stream=None, level=1):
|
||||
# store to stream
|
||||
__stream_write(msg)
|
||||
__stream_write('{:9}:{:5}:{:8}: {:14}\n'.format('seconds', 'pid', 'tid', 'self'))
|
||||
__stream_write('{:9}:{:5}:{:8}:{:15}\n'.format('-'*9, '-'*5, '-'*8, '-'*15))
|
||||
__stream_write('{:9}:{:5}:{:8}:{:15}\n'.format('-' * 9, '-' * 5, '-' * 8, '-' * 15))
|
||||
__trace_start = time.time()
|
||||
|
||||
_patch_module('trains')
|
||||
@@ -267,6 +268,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
|
||||
:param specify_pids: optional list of pids to include
|
||||
"""
|
||||
from glob import glob
|
||||
|
||||
def hash_line(a_line):
|
||||
return hash(':'.join(a_line.split(':')[1:]))
|
||||
|
||||
@@ -304,7 +306,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
|
||||
by_time = {}
|
||||
for p, tids in pids.items():
|
||||
for t, lines in tids.items():
|
||||
ts = float(lines[-1].split(':')[0].strip()) + 0.000001*len(by_time)
|
||||
ts = float(lines[-1].split(':')[0].strip()) + 0.000001 * len(by_time)
|
||||
if print_orphans:
|
||||
for i, l in enumerate(lines):
|
||||
if i > 0 and hash_line(l) in orphan_calls:
|
||||
@@ -313,7 +315,7 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi
|
||||
|
||||
out_stream = open(stream, 'w') if isinstance(stream, str) else stream
|
||||
for k in sorted(by_time.keys()):
|
||||
out_stream.write(by_time[k]+'\n')
|
||||
out_stream.write(by_time[k] + '\n')
|
||||
if isinstance(stream, str):
|
||||
out_stream.close()
|
||||
|
||||
@@ -322,6 +324,7 @@ def end_of_program():
|
||||
# stub
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# from trains import Task
|
||||
# task = Task.init(project_name="examples", task_name="trace test")
|
||||
|
||||
@@ -522,15 +522,15 @@ class Logger(object):
|
||||
"""
|
||||
# check if multiple series
|
||||
multi_series = (
|
||||
isinstance(scatter, list)
|
||||
and (
|
||||
isinstance(scatter[0], np.ndarray)
|
||||
or (
|
||||
scatter[0]
|
||||
and isinstance(scatter[0], list)
|
||||
and isinstance(scatter[0][0], list)
|
||||
)
|
||||
isinstance(scatter, list)
|
||||
and (
|
||||
isinstance(scatter[0], np.ndarray)
|
||||
or (
|
||||
scatter[0]
|
||||
and isinstance(scatter[0], list)
|
||||
and isinstance(scatter[0][0], list)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not multi_series:
|
||||
|
||||
@@ -109,9 +109,7 @@ class BaseModel(object):
|
||||
"""
|
||||
The Id (system UUID) of the model.
|
||||
|
||||
:return: The model id.
|
||||
|
||||
:rtype: str
|
||||
:return str: The model id.
|
||||
"""
|
||||
return self._get_model_data().id
|
||||
|
||||
@@ -121,9 +119,7 @@ class BaseModel(object):
|
||||
"""
|
||||
The name of the model.
|
||||
|
||||
:return: The model name.
|
||||
|
||||
:rtype: str
|
||||
:return str: The model name.
|
||||
"""
|
||||
return self._get_model_data().name
|
||||
|
||||
@@ -143,9 +139,7 @@ class BaseModel(object):
|
||||
"""
|
||||
The comment for the model. Also, use for a model description.
|
||||
|
||||
:return: The model comment / description.
|
||||
|
||||
:rtype: str
|
||||
:return str: The model comment / description.
|
||||
"""
|
||||
return self._get_model_data().comment
|
||||
|
||||
@@ -165,9 +159,7 @@ class BaseModel(object):
|
||||
"""
|
||||
A list of tags describing the model.
|
||||
|
||||
:return: The list of tags.
|
||||
|
||||
:rtype: list(str)
|
||||
:return list(str): The list of tags.
|
||||
"""
|
||||
return self._get_model_data().tags
|
||||
|
||||
@@ -189,9 +181,7 @@ class BaseModel(object):
|
||||
"""
|
||||
The configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
|
||||
|
||||
:return: The configuration.
|
||||
|
||||
:rtype: str
|
||||
:return str: The configuration.
|
||||
"""
|
||||
return _Model._unwrap_design(self._get_model_data().design)
|
||||
|
||||
@@ -202,9 +192,7 @@ class BaseModel(object):
|
||||
The configuration as a dictionary, parsed from the design text. This usually represents the model configuration.
|
||||
For example, prototxt, an ini file, or Python code to evaluate.
|
||||
|
||||
:return: The configuration.
|
||||
|
||||
:rtype: dict
|
||||
:return dict: The configuration.
|
||||
"""
|
||||
return self._text_to_config_dict(self.config_text)
|
||||
|
||||
@@ -215,9 +203,7 @@ class BaseModel(object):
|
||||
The label enumeration of string (label) to integer (value) pairs.
|
||||
|
||||
|
||||
:return: A dictionary containing labels enumeration, where the keys are labels and the values as integers.
|
||||
|
||||
:rtype: dict
|
||||
:return dict: A dictionary containing labels enumeration, where the keys are labels and the values as integers.
|
||||
"""
|
||||
return self._get_model_data().labels
|
||||
|
||||
@@ -266,9 +252,7 @@ class BaseModel(object):
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:return: The locally stored file.
|
||||
|
||||
:rtype: str
|
||||
:return str: The locally stored file.
|
||||
"""
|
||||
# download model (synchronously) and return local file
|
||||
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
|
||||
@@ -288,8 +272,6 @@ class BaseModel(object):
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
|
||||
:rtype: package or path
|
||||
"""
|
||||
# check if model was packaged
|
||||
if self._package_tag not in self._get_model_data().tags:
|
||||
@@ -508,7 +490,7 @@ class InputModel(Model):
|
||||
:type config_text: unconstrained text string
|
||||
:param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
|
||||
but not both.
|
||||
:param dict label_enumeration: The label enumeration dictionary of string (label) to integer (value) pairs. (Optional)
|
||||
:param dict label_enumeration: Optional label enumeration dictionary of string (label) to integer (value) pairs.
|
||||
|
||||
For example:
|
||||
|
||||
@@ -538,8 +520,6 @@ class InputModel(Model):
|
||||
:type framework: str or Framework object
|
||||
|
||||
:return: The imported model or existing model (see above).
|
||||
|
||||
:rtype: A model object.
|
||||
"""
|
||||
config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
|
||||
weights_url = StorageHelper.conform_url(weights_url)
|
||||
@@ -578,7 +558,7 @@ class InputModel(Model):
|
||||
from .task import Task
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
comment = 'Imported by task id: {}'.format(task.id) + ('\n'+comment if comment else '')
|
||||
comment = 'Imported by task id: {}'.format(task.id) + ('\n' + comment if comment else '')
|
||||
project_id = task.project
|
||||
task_id = task.id
|
||||
else:
|
||||
@@ -776,9 +756,7 @@ class OutputModel(BaseModel):
|
||||
"""
|
||||
Get the published state of this model.
|
||||
|
||||
:return: ``True`` if the model is published, ``False`` otherwise.
|
||||
|
||||
:rtype: bool
|
||||
:return bool: ``True`` if the model is published, ``False`` otherwise.
|
||||
"""
|
||||
if not self.id:
|
||||
return False
|
||||
@@ -790,9 +768,7 @@ class OutputModel(BaseModel):
|
||||
"""
|
||||
Get the configuration as a string. For example, prototxt, an ini file, or Python code to evaluate.
|
||||
|
||||
:return: The configuration.
|
||||
|
||||
:rtype: str
|
||||
:return str: The configuration.
|
||||
"""
|
||||
return _Model._unwrap_design(self._get_model_data().design)
|
||||
|
||||
@@ -811,9 +787,7 @@ class OutputModel(BaseModel):
|
||||
Get the configuration as a dictionary parsed from the ``config_text`` text. This usually represents the model
|
||||
configuration. For example, from prototxt to ini file or python code to evaluate.
|
||||
|
||||
:return: The configuration.
|
||||
|
||||
:rtype: dict
|
||||
:return dict: The configuration.
|
||||
"""
|
||||
return self._text_to_config_dict(self.config_text)
|
||||
|
||||
@@ -842,9 +816,7 @@ class OutputModel(BaseModel):
|
||||
'person': 1
|
||||
}
|
||||
|
||||
:return: The label enumeration.
|
||||
|
||||
:rtype: dict
|
||||
:return dict: The label enumeration.
|
||||
"""
|
||||
return self._get_model_data().labels
|
||||
|
||||
@@ -889,7 +861,8 @@ class OutputModel(BaseModel):
|
||||
Create a new model and immediately connect it to a task.
|
||||
|
||||
We do not allow Model creation without a task, so we always keep track of how the models were created.
|
||||
In remote execution, Model parameters can be overridden by the Task (such as model configuration & label enumerator)
|
||||
In remote execution, Model parameters can be overridden by the Task
|
||||
(such as model configuration & label enumerator)
|
||||
|
||||
:param task: The Task object with which the OutputModel object is associated.
|
||||
:type task: Task
|
||||
@@ -979,12 +952,12 @@ class OutputModel(BaseModel):
|
||||
if running_remotely() and task.is_main_task():
|
||||
if self._floating_data:
|
||||
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
|
||||
self._floating_data.design
|
||||
self._floating_data.design
|
||||
self._floating_data.labels = self._task.get_labels_enumeration() or \
|
||||
self._floating_data.labels
|
||||
self._floating_data.labels
|
||||
elif self._base_model:
|
||||
self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
|
||||
self._base_model.design)
|
||||
self._base_model.design)
|
||||
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
|
||||
|
||||
elif self._floating_data is not None:
|
||||
@@ -1005,8 +978,8 @@ class OutputModel(BaseModel):
|
||||
def set_upload_destination(self, uri):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Set the URI of the storage destination for uploaded model weight files. Supported storage destinations include
|
||||
S3, Google Cloud Storage), and file locations.
|
||||
Set the URI of the storage destination for uploaded model weight files.
|
||||
Supported storage destinations include S3, Google Cloud Storage, and file locations.
|
||||
|
||||
Using this method, file uploads are separate and then a link to each is stored in the model object.
|
||||
|
||||
@@ -1021,12 +994,10 @@ class OutputModel(BaseModel):
|
||||
- ``s3://bucket/directory/``
|
||||
- ``file:///tmp/debug/``
|
||||
|
||||
:return: The status of whether the storage destination schema is supported.
|
||||
:return bool: The status of whether the storage destination schema is supported.
|
||||
|
||||
- ``True`` - The storage destination scheme is supported.
|
||||
- ``False`` - The storage destination scheme is not supported.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
if not uri:
|
||||
return
|
||||
@@ -1063,8 +1034,8 @@ class OutputModel(BaseModel):
|
||||
.. note::
|
||||
Uploading the model is a background process. A call to this method returns immediately.
|
||||
|
||||
:param str weights_filename: The name of the locally stored weights file to upload. Specify ``weights_filename``
|
||||
or ``register_uri``, but not both.
|
||||
:param str weights_filename: The name of the locally stored weights file to upload.
|
||||
Specify ``weights_filename`` or ``register_uri``, but not both.
|
||||
:param str upload_uri: The URI of the storage destination for model weights upload. The default value
|
||||
is the previously used URI. (Optional)
|
||||
:param str target_filename: The newly created filename in the storage destination location. The default value
|
||||
@@ -1083,9 +1054,7 @@ class OutputModel(BaseModel):
|
||||
- ``True`` - Update model comment (Default)
|
||||
- ``False`` - Do not update
|
||||
|
||||
:return: The uploaded URI.
|
||||
|
||||
:rtype: str
|
||||
:return str: The uploaded URI.
|
||||
"""
|
||||
|
||||
def delete_previous_weights_file(filename=weights_filename):
|
||||
@@ -1120,7 +1089,8 @@ class OutputModel(BaseModel):
|
||||
if not model:
|
||||
raise ValueError('Failed creating internal output model')
|
||||
|
||||
# select the correct file extension based on the framework, or update the framework based on the file extension
|
||||
# select the correct file extension based on the framework,
|
||||
# or update the framework based on the file extension
|
||||
framework, file_ext = Framework._get_file_ext(
|
||||
framework=self._get_model_data().framework,
|
||||
filename=target_filename or weights_filename or register_uri
|
||||
@@ -1211,13 +1181,12 @@ class OutputModel(BaseModel):
|
||||
|
||||
:param int iteration: The iteration number.
|
||||
|
||||
:return: The uploaded URI for the weights package.
|
||||
|
||||
:rtype: str
|
||||
:return str: The uploaded URI for the weights package.
|
||||
"""
|
||||
# create list of files
|
||||
if (not weights_filenames and not weights_path) or (weights_filenames and weights_path):
|
||||
raise ValueError('Model update weights package should get either directory path to pack or a list of files')
|
||||
raise ValueError('Model update weights package should get either '
|
||||
'directory path to pack or a list of files')
|
||||
|
||||
if not weights_filenames:
|
||||
weights_filenames = list(map(six.text_type, Path(weights_path).rglob('*')))
|
||||
@@ -1276,14 +1245,13 @@ class OutputModel(BaseModel):
|
||||
:param dict config_dict: The configuration as a dictionary. Specify ``config_text`` or ``config_dict``,
|
||||
but not both.
|
||||
|
||||
:return: The status of the update.
|
||||
:return bool: The status of the update.
|
||||
|
||||
- ``True`` - Update successful.
|
||||
- ``False`` - Update not successful.
|
||||
:rtype: bool
|
||||
"""
|
||||
if not self._validate_update():
|
||||
return
|
||||
return False
|
||||
|
||||
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
|
||||
|
||||
|
||||
@@ -575,27 +575,27 @@ class StorageHelper(object):
|
||||
def list(self, prefix=None):
|
||||
"""
|
||||
List entries in the helper base path.
|
||||
|
||||
|
||||
Return a list of names inside this helper base path. The base path is
|
||||
determined at creation time and is specific for each storage medium.
|
||||
For Google Storage and S3 it is the bucket of the path.
|
||||
For local files it is the root directory.
|
||||
|
||||
|
||||
This operation is not supported for http and https protocols.
|
||||
|
||||
|
||||
:param prefix: If None, return the list as described above. If not, it
|
||||
must be a string - the path of a sub directory under the base path.
|
||||
The returned list will include only objects under that subdir.
|
||||
|
||||
|
||||
:return: List of strings - the paths of all the objects in the storage base
|
||||
path under prefix. Listed relative to the base path.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
if prefix:
|
||||
if prefix.startswith(self._base_url):
|
||||
prefix = prefix[len(self.base_url):].lstrip("/")
|
||||
|
||||
|
||||
try:
|
||||
res = self._driver.list_container_objects(self._container, ex_prefix=prefix)
|
||||
except TypeError:
|
||||
@@ -943,7 +943,7 @@ class StorageHelper(object):
|
||||
cb(dest_path)
|
||||
except Exception as e:
|
||||
self._log.warning("Exception on upload callback: %s" % str(e))
|
||||
|
||||
|
||||
return dest_path
|
||||
|
||||
def _get_object(self, path):
|
||||
@@ -1012,8 +1012,8 @@ class _HttpDriver(_Driver):
|
||||
|
||||
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
|
||||
url = object_name[:object_name.index('/')]
|
||||
url_path = object_name[len(url)+1:]
|
||||
full_url = container.name+url
|
||||
url_path = object_name[len(url) + 1:]
|
||||
full_url = container.name + url
|
||||
# when sending data in post, there is no connection timeout, just an entire upload timeout
|
||||
timeout = self.timeout[-1]
|
||||
if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'):
|
||||
@@ -1668,7 +1668,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
|
||||
def callback_func(current, total):
|
||||
if callback:
|
||||
chunk = current-download_done.counter
|
||||
chunk = current - download_done.counter
|
||||
download_done.counter += chunk
|
||||
callback(chunk)
|
||||
if current >= total:
|
||||
|
||||
@@ -618,16 +618,22 @@ class Task(_Task):
|
||||
return self.get_models()
|
||||
|
||||
@classmethod
|
||||
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
|
||||
# type: (Optional[Task], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> Task
|
||||
def clone(
|
||||
cls,
|
||||
source_task=None, # type: Optional[Union[Task, str]]
|
||||
name=None, # type: Optional[str]
|
||||
comment=None, # type: Optional[str]
|
||||
parent=None, # type: Optional[str]
|
||||
project=None, # type: Optional[str]
|
||||
):
|
||||
# type: (...) -> Task
|
||||
"""
|
||||
Create a duplicate (a clone) of a Task (experiment). The status of the cloned Task is ``Draft``
|
||||
and modifiable.
|
||||
|
||||
Use this method to manage experiments and for autoML.
|
||||
|
||||
:param source_task: The Task to clone. Specify a Task object or a Task Id. (Optional)
|
||||
:type source_task: Task/str
|
||||
:param str source_task: The Task to clone. Specify a Task object or a Task Id. (Optional)
|
||||
:param str name: The name of the new cloned Task. (Optional)
|
||||
:param str comment: A comment / description for the new cloned Task. (Optional)
|
||||
:param str parent: The Id of the parent Task of the new Task.
|
||||
|
||||
@@ -37,7 +37,7 @@ def range_validator(min_value, max_value):
|
||||
"""
|
||||
def _range_validator(instance, attribute, value):
|
||||
if ((min_value is not None) and (value < min_value)) or \
|
||||
((max_value is not None) and (value > max_value)):
|
||||
((max_value is not None) and (value > max_value)):
|
||||
raise ValueError("{} must be in range [{}, {}]".format(attribute.name, min_value, max_value))
|
||||
|
||||
return _range_validator
|
||||
@@ -159,5 +159,3 @@ class TaskParameters(object):
|
||||
"""
|
||||
|
||||
return task.connect(self)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class PatchArgumentParser:
|
||||
def add_subparsers(self, **kwargs):
|
||||
if 'dest' not in kwargs:
|
||||
if kwargs.get('title'):
|
||||
kwargs['dest'] = '/'+kwargs['title']
|
||||
kwargs['dest'] = '/' + kwargs['title']
|
||||
else:
|
||||
PatchArgumentParser._add_subparsers_counter += 1
|
||||
kwargs['dest'] = '/subparser%d' % PatchArgumentParser._add_subparsers_counter
|
||||
@@ -80,14 +80,16 @@ class PatchArgumentParser:
|
||||
# if we do we need to patch the args, because there is no default subparser
|
||||
if PY2:
|
||||
import itertools
|
||||
|
||||
def _get_sub_parsers_defaults(subparser, prev=[]):
|
||||
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(subparser, _SubParsersAction) else \
|
||||
[subparser._actions]
|
||||
sub_parsers_defaults = [[subparser]] if hasattr(subparser, 'default') and subparser.default else []
|
||||
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(
|
||||
subparser, _SubParsersAction) else [subparser._actions]
|
||||
sub_parsers_defaults = [[subparser]] if hasattr(
|
||||
subparser, 'default') and subparser.default else []
|
||||
for actions in actions_grp:
|
||||
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev)
|
||||
for a in actions if isinstance(a, _SubParsersAction) and
|
||||
hasattr(a, 'default') and a.default]
|
||||
for a in actions if isinstance(a, _SubParsersAction) and
|
||||
hasattr(a, 'default') and a.default]
|
||||
|
||||
return list(itertools.chain.from_iterable(sub_parsers_defaults))
|
||||
sub_parsers_defaults = _get_sub_parsers_defaults(self)
|
||||
|
||||
@@ -39,8 +39,8 @@ def get_percentage(config, key, required=True, default=None):
|
||||
|
||||
def get_human_size_default(config, key, default=None):
|
||||
raw_value = config.get(key, default)
|
||||
|
||||
|
||||
if raw_value is None:
|
||||
return default
|
||||
|
||||
return parse_human_size(raw_value)
|
||||
|
||||
return parse_human_size(raw_value)
|
||||
|
||||
@@ -118,4 +118,3 @@ class DeferredExecution(object):
|
||||
return func(instance, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ class BlobsDict(dict):
|
||||
"""
|
||||
Overloading getitem so that the 'data' copy is only done when the dictionary item is accessed.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BlobsDict, self).__init__(*args, **kwargs)
|
||||
|
||||
@@ -60,6 +61,7 @@ class BlobsDict(dict):
|
||||
class NestedBlobsDict(BlobsDict):
|
||||
"""A dictionary that applies an arbitrary key-altering function
|
||||
before accessing the keys."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(NestedBlobsDict, self).__init__(*args, **kwargs)
|
||||
|
||||
@@ -96,14 +98,14 @@ class NestedBlobsDict(BlobsDict):
|
||||
for key in cur_keys:
|
||||
if isinstance(cur_dict[key], dict):
|
||||
if len(path) > 0:
|
||||
deep_keys.extend(self._keys(cur_dict[key], path+ '.' + key))
|
||||
deep_keys.extend(self._keys(cur_dict[key], path + '.' + key))
|
||||
else:
|
||||
deep_keys.extend(self._keys(cur_dict[key], key))
|
||||
else:
|
||||
if len(path) > 0:
|
||||
deep_keys.append(path + '.' + key)
|
||||
else:
|
||||
deep_keys.append( key)
|
||||
deep_keys.append(key)
|
||||
|
||||
return deep_keys
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ import threading
|
||||
import string
|
||||
|
||||
## C Type mappings ##
|
||||
## Enums
|
||||
# Enums
|
||||
_nvmlEnableState_t = c_uint
|
||||
NVML_FEATURE_DISABLED = 0
|
||||
NVML_FEATURE_ENABLED = 1
|
||||
@@ -347,7 +347,7 @@ def _nvmlGetFunctionPointer(name):
|
||||
libLoadLock.release()
|
||||
|
||||
|
||||
## Alternative object
|
||||
# Alternative object
|
||||
# Allows the object to be printed
|
||||
# Allows mismatched types to be assigned
|
||||
# - like None when the Structure variant requires c_uint
|
||||
@@ -379,7 +379,7 @@ def nvmlFriendlyObjectToStruct(obj, model):
|
||||
return model
|
||||
|
||||
|
||||
## Unit structures
|
||||
# Unit structures
|
||||
class struct_c_nvmlUnit_t(Structure):
|
||||
pass # opaque handle
|
||||
|
||||
@@ -461,7 +461,7 @@ class c_nvmlUnitFanSpeeds_t(_PrintableStructure):
|
||||
]
|
||||
|
||||
|
||||
## Device structures
|
||||
# Device structures
|
||||
class struct_c_nvmlDevice_t(Structure):
|
||||
pass # opaque handle
|
||||
|
||||
@@ -591,7 +591,7 @@ class c_nvmlViolationTime_t(_PrintableStructure):
|
||||
]
|
||||
|
||||
|
||||
## Event structures
|
||||
# Event structures
|
||||
class struct_c_nvmlEventSet_t(Structure):
|
||||
pass # opaque handle
|
||||
|
||||
@@ -605,29 +605,30 @@ nvmlEventTypeXidCriticalError = 0x0000000000000008
|
||||
nvmlEventTypeClock = 0x0000000000000010
|
||||
nvmlEventTypeNone = 0x0000000000000000
|
||||
nvmlEventTypeAll = (
|
||||
nvmlEventTypeNone |
|
||||
nvmlEventTypeSingleBitEccError |
|
||||
nvmlEventTypeDoubleBitEccError |
|
||||
nvmlEventTypePState |
|
||||
nvmlEventTypeClock |
|
||||
nvmlEventTypeXidCriticalError
|
||||
nvmlEventTypeNone |
|
||||
nvmlEventTypeSingleBitEccError |
|
||||
nvmlEventTypeDoubleBitEccError |
|
||||
nvmlEventTypePState |
|
||||
nvmlEventTypeClock |
|
||||
nvmlEventTypeXidCriticalError
|
||||
)
|
||||
|
||||
## Clock Throttle Reasons defines
|
||||
# Clock Throttle Reasons defines
|
||||
nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001
|
||||
nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002
|
||||
nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting
|
||||
# deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting
|
||||
nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting
|
||||
nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004
|
||||
nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008
|
||||
nvmlClocksThrottleReasonUnknown = 0x8000000000000000
|
||||
nvmlClocksThrottleReasonNone = 0x0000000000000000
|
||||
nvmlClocksThrottleReasonAll = (
|
||||
nvmlClocksThrottleReasonNone |
|
||||
nvmlClocksThrottleReasonGpuIdle |
|
||||
nvmlClocksThrottleReasonApplicationsClocksSetting |
|
||||
nvmlClocksThrottleReasonSwPowerCap |
|
||||
nvmlClocksThrottleReasonHwSlowdown |
|
||||
nvmlClocksThrottleReasonUnknown
|
||||
nvmlClocksThrottleReasonNone |
|
||||
nvmlClocksThrottleReasonGpuIdle |
|
||||
nvmlClocksThrottleReasonApplicationsClocksSetting |
|
||||
nvmlClocksThrottleReasonSwPowerCap |
|
||||
nvmlClocksThrottleReasonHwSlowdown |
|
||||
nvmlClocksThrottleReasonUnknown
|
||||
)
|
||||
|
||||
|
||||
@@ -785,7 +786,7 @@ def nvmlSystemGetHicVersion():
|
||||
return hics
|
||||
|
||||
|
||||
## Unit get functions
|
||||
# Unit get functions
|
||||
def nvmlUnitGetCount():
|
||||
c_count = c_uint()
|
||||
fn = _nvmlGetFunctionPointer("nvmlUnitGetCount")
|
||||
@@ -865,7 +866,7 @@ def nvmlUnitGetDevices(unit):
|
||||
return c_devices
|
||||
|
||||
|
||||
## Device get functions
|
||||
# Device get functions
|
||||
def nvmlDeviceGetCount():
|
||||
c_count = c_uint()
|
||||
fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2")
|
||||
@@ -919,7 +920,7 @@ def nvmlDeviceGetName(handle):
|
||||
|
||||
|
||||
def nvmlDeviceGetBoardId(handle):
|
||||
c_id = c_uint();
|
||||
c_id = c_uint()
|
||||
fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId")
|
||||
ret = fn(handle, byref(c_id))
|
||||
_nvmlCheckReturn(ret)
|
||||
@@ -927,7 +928,7 @@ def nvmlDeviceGetBoardId(handle):
|
||||
|
||||
|
||||
def nvmlDeviceGetMultiGpuBoard(handle):
|
||||
c_multiGpu = c_uint();
|
||||
c_multiGpu = c_uint()
|
||||
fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard")
|
||||
ret = fn(handle, byref(c_multiGpu))
|
||||
_nvmlCheckReturn(ret)
|
||||
@@ -1480,7 +1481,7 @@ def nvmlDeviceGetAutoBoostedClocksEnabled(handle):
|
||||
# Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks
|
||||
|
||||
|
||||
## Set functions
|
||||
# Set functions
|
||||
def nvmlUnitSetLedState(unit, color):
|
||||
fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState")
|
||||
ret = fn(unit, _nvmlLedColor_t(color))
|
||||
@@ -1800,7 +1801,7 @@ def nvmlDeviceGetSamples(device, sampling_type, timeStamp):
|
||||
c_sample_value_type = _nvmlValueType_t()
|
||||
fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples")
|
||||
|
||||
## First Call gets the size
|
||||
# First Call gets the size
|
||||
ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), None)
|
||||
|
||||
# Stop if this fails
|
||||
@@ -1819,7 +1820,7 @@ def nvmlDeviceGetViolationStatus(device, perfPolicyType):
|
||||
c_violTime = c_nvmlViolationTime_t()
|
||||
fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus")
|
||||
|
||||
## Invoke the method to get violation time
|
||||
# Invoke the method to get violation time
|
||||
ret = fn(device, c_perfPolicy_type, byref(c_violTime))
|
||||
_nvmlCheckReturn(ret)
|
||||
return c_violTime
|
||||
|
||||
@@ -148,4 +148,3 @@ elif os.name == 'posix': # pragma: no cover
|
||||
|
||||
else: # pragma: no cover
|
||||
raise RuntimeError('PortaLocker only defined for nt and posix platforms')
|
||||
|
||||
|
||||
@@ -218,6 +218,7 @@ class RLock(Lock):
|
||||
can be acquired multiple times. When the corresponding number of release()
|
||||
calls are made the lock will finally release the underlying file lock.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, filename, mode='a', timeout=DEFAULT_TIMEOUT,
|
||||
check_interval=DEFAULT_CHECK_INTERVAL, fail_when_locked=False,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ def project_import_modules(project_path, ignores):
|
||||
continue
|
||||
# Hack detect if this is a virtual-env folder, if so add it to the uignore list
|
||||
if set(dirnames) == venv_subdirs:
|
||||
ignore_absolute.append(Path(dirpath).as_posix()+os.sep)
|
||||
ignore_absolute.append(Path(dirpath).as_posix() + os.sep)
|
||||
continue
|
||||
|
||||
py_files = list()
|
||||
@@ -132,7 +132,7 @@ class ImportChecker(object):
|
||||
level -= 1
|
||||
mod_name = ''
|
||||
for alias in node.names:
|
||||
name = level*'.' + mod_name + '.' + alias.name
|
||||
name = level * '.' + mod_name + '.' + alias.name
|
||||
self._modules.add(name, self._fpath, node.lineno + self._lineno)
|
||||
if try_:
|
||||
self._try_imports.add(name)
|
||||
|
||||
@@ -92,7 +92,7 @@ class Archive(object):
|
||||
def is_safe(self, filename):
|
||||
return not (filename.startswith(("/", "\\")) or
|
||||
(len(filename) > 1 and filename[1] == ":" and
|
||||
filename[0] in string.ascii_letter) or
|
||||
filename[0] in string.ascii_letter) or
|
||||
re.search(r"[.][.][/\\]", filename))
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@@ -102,11 +102,11 @@ def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=
|
||||
for s in series:
|
||||
# if we need to down-sample, use low-pass average filter and sampling
|
||||
if s.data.size >= base_size:
|
||||
budget = int(leftover * s.data.size/(total_size-baseused_size))
|
||||
budget = int(leftover * s.data.size / (total_size - baseused_size))
|
||||
step = int(np.ceil(s.data.size / float(budget)))
|
||||
x = s.data[:, 0][::-step][::-1]
|
||||
y = s.data[:, 1]
|
||||
y_low_pass = np.convolve(y, np.ones(shape=(step,), dtype=y.dtype)/float(step), mode='same')
|
||||
y_low_pass = np.convolve(y, np.ones(shape=(step,), dtype=y.dtype) / float(step), mode='same')
|
||||
y = y_low_pass[::-step][::-1]
|
||||
s.data = np.array([x, y], dtype=s.data.dtype).T
|
||||
|
||||
@@ -186,7 +186,8 @@ def create_3d_scatter_series(np_row_wise, title="Scatter", series_name="Series",
|
||||
:return:
|
||||
"""
|
||||
if not plotly_obj:
|
||||
plotly_obj = plotly_scatter3d_layout_dict(title=title, xaxis_title=xtitle, yaxis_title=ytitle, zaxis_title=ztitle)
|
||||
plotly_obj = plotly_scatter3d_layout_dict(
|
||||
title=title, xaxis_title=xtitle, yaxis_title=ytitle, zaxis_title=ztitle)
|
||||
assert np_row_wise.ndim == 2, "Expected a 2D numpy array"
|
||||
assert np_row_wise.shape[1] == 3, "Expected three columns X/Y/Z e.g. [(x0,y0,z0), (x1,y1,z1) ...]"
|
||||
|
||||
@@ -282,7 +283,7 @@ def create_3d_surface(np_value_matrix, title="3D Surface", xlabels=None, ylabels
|
||||
"yaxis": {
|
||||
"title": ytitle,
|
||||
"showgrid": False,
|
||||
"nticks": 10,
|
||||
"nticks": 10,
|
||||
"ticktext": ylabels,
|
||||
"tickvals": list(range(len(ylabels))) if ylabels else ylabels,
|
||||
},
|
||||
|
||||
@@ -74,7 +74,7 @@ class ProxyDictPreWrite(dict):
|
||||
return key_value
|
||||
|
||||
def _nested_callback(self, prefix, key_value):
|
||||
return self._set_callback((prefix+'.'+key_value[0], key_value[1],))
|
||||
return self._set_callback((prefix + '.' + key_value[0], key_value[1],))
|
||||
|
||||
|
||||
def flatten_dictionary(a_dict, prefix=''):
|
||||
@@ -84,15 +84,15 @@ def flatten_dictionary(a_dict, prefix=''):
|
||||
for k, v in a_dict.items():
|
||||
k = str(k)
|
||||
if isinstance(v, (float, int, bool, six.string_types)):
|
||||
flat_dict[prefix+k] = v
|
||||
flat_dict[prefix + k] = v
|
||||
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
|
||||
flat_dict[prefix+k] = v
|
||||
flat_dict[prefix + k] = v
|
||||
elif isinstance(v, dict):
|
||||
flat_dict.update(flatten_dictionary(v, prefix=prefix+k+sep))
|
||||
flat_dict.update(flatten_dictionary(v, prefix=prefix + k + sep))
|
||||
else:
|
||||
# this is a mixture of list and dict, or any other object,
|
||||
# leave it as is, we have nothing to do with it.
|
||||
flat_dict[prefix+k] = v
|
||||
flat_dict[prefix + k] = v
|
||||
return flat_dict
|
||||
|
||||
|
||||
@@ -102,15 +102,15 @@ def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
|
||||
for k, v in a_dict.items():
|
||||
k = str(k)
|
||||
if isinstance(v, (float, int, bool, six.string_types)):
|
||||
a_dict[k] = flat_dict.get(prefix+k, v)
|
||||
a_dict[k] = flat_dict.get(prefix + k, v)
|
||||
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
|
||||
a_dict[k] = flat_dict.get(prefix+k, v)
|
||||
a_dict[k] = flat_dict.get(prefix + k, v)
|
||||
elif isinstance(v, dict):
|
||||
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix+k+sep) or v
|
||||
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep) or v
|
||||
else:
|
||||
# this is a mixture of list and dict, or any other object,
|
||||
# leave it as is, we have nothing to do with it.
|
||||
a_dict[k] = flat_dict.get(prefix+k, v)
|
||||
a_dict[k] = flat_dict.get(prefix + k, v)
|
||||
return a_dict
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ def naive_nested_from_flat_dictionary(flat_dict, sep='/'):
|
||||
bucket[0][1] if (len(bucket) == 1 and sub_prefix == bucket[0][0])
|
||||
else naive_nested_from_flat_dictionary(
|
||||
{
|
||||
k[len(sub_prefix)+1:]: v
|
||||
k[len(sub_prefix) + 1:]: v
|
||||
for k, v in bucket
|
||||
if len(k) > len(sub_prefix)
|
||||
}
|
||||
|
||||
@@ -366,7 +366,7 @@ class ConfigParser(object):
|
||||
null_expr = Keyword("null", caseless=True).setParseAction(replaceWith(NoneValue()))
|
||||
# key = QuotedString('"', escChar='\\', unquoteResults=False) | Word(alphanums + alphas8bit + '._- /')
|
||||
key = QuotedString('"', escChar='\\', unquoteResults=False) | \
|
||||
Word("0123456789.").setParseAction(safe_convert_number) | Word(alphanums + alphas8bit + '._- /')
|
||||
Word("0123456789.").setParseAction(safe_convert_number) | Word(alphanums + alphas8bit + '._- /')
|
||||
|
||||
eol = Word('\n\r').suppress()
|
||||
eol_comma = Word('\n\r,').suppress()
|
||||
@@ -390,13 +390,15 @@ class ConfigParser(object):
|
||||
# line1 \
|
||||
# line2 \
|
||||
# so a backslash precedes the \n
|
||||
unquoted_string = Regex(r'(?:[^^`+?!@*&"\[\{\s\]\}#,=\$\\]|\\.)+[ \t]*', re.UNICODE).setParseAction(unescape_string)
|
||||
unquoted_string = Regex(r'(?:[^^`+?!@*&"\[\{\s\]\}#,=\$\\]|\\.)+[ \t]*',
|
||||
re.UNICODE).setParseAction(unescape_string)
|
||||
substitution_expr = Regex(r'[ \t]*\$\{[^\}]+\}[ \t]*').setParseAction(create_substitution)
|
||||
string_expr = multiline_string | quoted_string | unquoted_string
|
||||
|
||||
value_expr = period_expr | number_expr | true_expr | false_expr | null_expr | string_expr
|
||||
|
||||
include_content = (quoted_string | ((Keyword('url') | Keyword('file')) - Literal('(').suppress() - quoted_string - Literal(')').suppress()))
|
||||
include_content = (quoted_string | ((Keyword('url') | Keyword(
|
||||
'file')) - Literal('(').suppress() - quoted_string - Literal(')').suppress()))
|
||||
include_expr = (
|
||||
Keyword("include", caseless=True).suppress() + (
|
||||
include_content | (
|
||||
@@ -408,33 +410,34 @@ class ConfigParser(object):
|
||||
root_dict_expr = Forward()
|
||||
dict_expr = Forward()
|
||||
list_expr = Forward()
|
||||
multi_value_expr = ZeroOrMore(comment_eol | include_expr | substitution_expr | dict_expr | list_expr | value_expr | (Literal(
|
||||
'\\') - eol).suppress())
|
||||
multi_value_expr = ZeroOrMore(comment_eol | include_expr | substitution_expr |
|
||||
dict_expr | list_expr | value_expr | (Literal('\\') - eol).suppress())
|
||||
# for a dictionary : or = is optional
|
||||
# last zeroOrMore is because we can have t = {a:4} {b: 6} {c: 7} which is dictionary concatenation
|
||||
inside_dict_expr = ConfigTreeParser(ZeroOrMore(comment_eol | include_expr | assign_expr | eol_comma))
|
||||
inside_root_dict_expr = ConfigTreeParser(ZeroOrMore(comment_eol | include_expr | assign_expr | eol_comma), root=True)
|
||||
inside_root_dict_expr = ConfigTreeParser(ZeroOrMore(
|
||||
comment_eol | include_expr | assign_expr | eol_comma), root=True)
|
||||
dict_expr << Suppress('{') - inside_dict_expr - Suppress('}')
|
||||
root_dict_expr << Suppress('{') - inside_root_dict_expr - Suppress('}')
|
||||
list_entry = ConcatenatedValueParser(multi_value_expr)
|
||||
list_expr << Suppress('[') - ListParser(list_entry - ZeroOrMore(eol_comma - list_entry)) - Suppress(']')
|
||||
|
||||
# special case when we have a value assignment where the string can potentially be the remainder of the line
|
||||
assign_expr << Group(
|
||||
key - ZeroOrMore(comment_no_comma_eol) - (dict_expr | (Literal('=') | Literal(':') | Literal('+=')) - ZeroOrMore(
|
||||
comment_no_comma_eol) - ConcatenatedValueParser(multi_value_expr))
|
||||
)
|
||||
assign_expr << Group(key - ZeroOrMore(comment_no_comma_eol) -
|
||||
(dict_expr | (Literal('=') | Literal(':') | Literal('+=')) -
|
||||
ZeroOrMore(comment_no_comma_eol) - ConcatenatedValueParser(multi_value_expr)))
|
||||
|
||||
# the file can be { ... } where {} can be omitted or []
|
||||
config_expr = ZeroOrMore(comment_eol | eol) + (list_expr | root_dict_expr | inside_root_dict_expr) + ZeroOrMore(
|
||||
comment_eol | eol_comma)
|
||||
config_expr = ZeroOrMore(comment_eol | eol) + (list_expr | root_dict_expr |
|
||||
inside_root_dict_expr) + ZeroOrMore(comment_eol | eol_comma)
|
||||
config = config_expr.parseString(content, parseAll=True)[0]
|
||||
|
||||
if resolve:
|
||||
allow_unresolved = resolve and unresolved_value is not DEFAULT_SUBSTITUTION and unresolved_value is not MANDATORY_SUBSTITUTION
|
||||
has_unresolved = cls.resolve_substitutions(config, allow_unresolved)
|
||||
if has_unresolved and unresolved_value is MANDATORY_SUBSTITUTION:
|
||||
raise ConfigSubstitutionException('resolve cannot be set to True and unresolved_value to MANDATORY_SUBSTITUTION')
|
||||
raise ConfigSubstitutionException(
|
||||
'resolve cannot be set to True and unresolved_value to MANDATORY_SUBSTITUTION')
|
||||
|
||||
if unresolved_value is not NO_SUBSTITUTION and unresolved_value is not DEFAULT_SUBSTITUTION:
|
||||
cls.unresolve_substitutions_to_value(config, unresolved_value)
|
||||
@@ -485,14 +488,16 @@ class ConfigParser(object):
|
||||
if len(prop_path) > 1 and config.get(substitution.variable, None) is not None:
|
||||
continue # If value is present in latest version, don't do anything
|
||||
if prop_path[0] == key:
|
||||
if isinstance(previous_item, ConfigValues) and not accept_unresolved: # We hit a dead end, we cannot evaluate
|
||||
if isinstance(
|
||||
previous_item, ConfigValues) and not accept_unresolved: # We hit a dead end, we cannot evaluate
|
||||
raise ConfigSubstitutionException(
|
||||
"Property {variable} cannot be substituted. Check for cycles.".format(
|
||||
variable=substitution.variable
|
||||
)
|
||||
)
|
||||
else:
|
||||
value = previous_item if len(prop_path) == 1 else previous_item.get(".".join(prop_path[1:]))
|
||||
value = previous_item if len(
|
||||
prop_path) == 1 else previous_item.get(".".join(prop_path[1:]))
|
||||
_, _, current_item = cls._do_substitute(substitution, value)
|
||||
previous_item = current_item
|
||||
|
||||
@@ -632,7 +637,8 @@ class ConfigParser(object):
|
||||
# self resolution, backtrack
|
||||
resolved_value = substitution.parent.overriden_value
|
||||
|
||||
unresolved, new_substitutions, result = cls._do_substitute(substitution, resolved_value, is_optional_resolved)
|
||||
unresolved, new_substitutions, result = cls._do_substitute(
|
||||
substitution, resolved_value, is_optional_resolved)
|
||||
any_unresolved = unresolved or any_unresolved
|
||||
substitutions.extend(new_substitutions)
|
||||
if not isinstance(result, ConfigValues):
|
||||
|
||||
@@ -148,7 +148,8 @@ class ConfigTree(OrderedDict):
|
||||
|
||||
if elt is UndefinedKey:
|
||||
if default is UndefinedKey:
|
||||
raise ConfigMissingException(u"No configuration setting found for key {key}".format(key='.'.join(key_path[:key_index + 1])))
|
||||
raise ConfigMissingException(u"No configuration setting found for key {key}".format(
|
||||
key='.'.join(key_path[: key_index + 1])))
|
||||
else:
|
||||
return default
|
||||
|
||||
@@ -184,7 +185,9 @@ class ConfigTree(OrderedDict):
|
||||
return [string]
|
||||
|
||||
special_characters = '$}[]:=+#`^?!@*&.'
|
||||
tokens = re.findall(r'"[^"]+"|[^{special_characters}]+'.format(special_characters=re.escape(special_characters)), string)
|
||||
tokens = re.findall(
|
||||
r'"[^"]+"|[^{special_characters}]+'.format(special_characters=re.escape(special_characters)),
|
||||
string)
|
||||
|
||||
def contains_special_character(token):
|
||||
return any((c in special_characters) for c in token)
|
||||
|
||||
@@ -20,7 +20,7 @@ except NameError:
|
||||
class HOCONConverter(object):
|
||||
_number_re = r'[+-]?(\d*\.\d+|\d+(\.\d+)?)([eE][+\-]?\d+)?(?=$|[ \t]*([\$\}\],#\n\r]|//))'
|
||||
_number_re_matcher = re.compile(_number_re)
|
||||
|
||||
|
||||
@classmethod
|
||||
def to_json(cls, config, compact=False, indent=2, level=0):
|
||||
"""Convert HOCON input into a JSON output
|
||||
@@ -150,7 +150,7 @@ class HOCONConverter(object):
|
||||
new_line = cls.to_hocon(item, compact, indent, level + 1)
|
||||
lines += new_line
|
||||
if '\n' in new_line or len(lines) - base_len > 80:
|
||||
if i < len(config)-1:
|
||||
if i < len(config) - 1:
|
||||
lines += ',\n{indent}'.format(indent=''.rjust(level * indent, ' '))
|
||||
base_len = len(lines)
|
||||
skip_comma = True
|
||||
|
||||
@@ -45,7 +45,7 @@ class ResourceMonitor(object):
|
||||
else: # if running_remotely():
|
||||
try:
|
||||
active_gpus = os.environ.get('NVIDIA_VISIBLE_DEVICES', '') or \
|
||||
os.environ.get('CUDA_VISIBLE_DEVICES', '')
|
||||
os.environ.get('CUDA_VISIBLE_DEVICES', '')
|
||||
if active_gpus:
|
||||
self._active_gpus = [int(g.strip()) for g in active_gpus.split(',')]
|
||||
except Exception:
|
||||
@@ -150,7 +150,7 @@ class ResourceMonitor(object):
|
||||
try:
|
||||
title = self._title_gpu if k.startswith('gpu_') else self._title_machine
|
||||
# 3 points after the dot
|
||||
value = round(v*1000) / 1000.
|
||||
value = round(v * 1000) / 1000.
|
||||
self._task.get_logger().report_scalar(title=title, series=k, iteration=iteration, value=value)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -178,7 +178,7 @@ class ResourceMonitor(object):
|
||||
return self._num_readouts
|
||||
|
||||
def _get_average_readouts(self):
|
||||
average_readouts = dict((k, v/float(self._num_readouts)) for k, v in self._readouts.items())
|
||||
average_readouts = dict((k, v / float(self._num_readouts)) for k, v in self._readouts.items())
|
||||
return average_readouts
|
||||
|
||||
def _clear_readouts(self):
|
||||
@@ -205,7 +205,7 @@ class ResourceMonitor(object):
|
||||
self._get_process_used_memory() if self._process_info else virtual_memory.used) / 1024
|
||||
stats["memory_free_gb"] = bytes_to_megabytes(virtual_memory.available) / 1024
|
||||
disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
|
||||
stats["disk_free_percent"] = 100.0-disk_use_percentage
|
||||
stats["disk_free_percent"] = 100.0 - disk_use_percentage
|
||||
with warnings.catch_warnings():
|
||||
if logging.root.level > logging.DEBUG: # If the logging level is bigger than debug, ignore
|
||||
# psutil.sensors_temperatures warnings
|
||||
|
||||
@@ -6,6 +6,7 @@ try:
|
||||
except Exception:
|
||||
np = None
|
||||
|
||||
|
||||
def make_deterministic(seed=1337, cudnn_deterministic=False):
|
||||
"""
|
||||
Ensure deterministic behavior across PyTorch using the provided random seed.
|
||||
|
||||
Reference in New Issue
Block a user