Merge remote-tracking branch 'upstream/master'

Erez Schanider 2019-07-21 08:48:35 +03:00
commit 94642af4a1
16 changed files with 522 additions and 256 deletions

View File

@ -1,7 +1,13 @@
# TRAINS SDK configuration file
api {
# Notice: 'host' is the api server (default port 8008), not the web server.
host: http://localhost:8008
# web_server on port 8080
web_server: "http://localhost:8080"
# Notice: 'api_server' is the api server (default port 8008), not the web server.
api_server: "http://localhost:8008"
# file server on port 8081
files_server: "http://localhost:8081"
# Credentials are generated in the webapp, http://localhost:8080/admin
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}

View File

@ -1,7 +1,8 @@
# TRAINS - Example of Matplotlib integration and reporting
# TRAINS - Example of Matplotlib and Seaborn integration and reporting
#
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from trains import Task
@ -33,4 +34,13 @@ plt.imshow(m)
plt.title('Image Title')
plt.show()
print('This is a Matplotlib example')
sns.set(style="darkgrid")
# Load an example dataset with long-form data
fmri = sns.load_dataset("fmri")
# Plot the responses for different events and regions
sns.lineplot(x="timepoint", y="signal",
hue="region", style="event",
data=fmri)
plt.show()
print('This is a Matplotlib & Seaborn example')

View File

@ -1,7 +1,11 @@
{
version: 1.5
# default https://demoapi.trainsai.io host
host: ""
# default api_server: https://demoapi.trainsai.io
api_server: ""
# default web_server: https://demoapp.trainsai.io
web_server: ""
# default files_server: https://demofiles.trainsai.io
files_server: ""
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True

View File

@ -2,6 +2,8 @@ from ...backend_config import EnvEntry
ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "ALG_WEB_HOST")
ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "ALG_FILES_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)

View File

@ -2,6 +2,7 @@ import json as json_lib
import sys
import types
from socket import gethostname
from six.moves.urllib.parse import urlparse, urlunparse
import jwt
import requests
@ -10,11 +11,11 @@ from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth
from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ..version import __version__
@ -32,11 +33,13 @@ class Session(TokenManager):
_async_status_code = 202
_session_requests = 0
_session_initial_timeout = (1.0, 10)
_session_timeout = (5.0, None)
_session_initial_timeout = (3.0, 10.)
_session_timeout = (5.0, 300.)
api_version = '2.1'
default_host = "https://demoapi.trainsai.io"
default_web = "https://demoapp.trainsai.io"
default_files = "https://demofiles.trainsai.io"
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
@ -97,7 +100,7 @@ class Session(TokenManager):
self._logger = logger
self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=(self.config.get("api.credentials.access_key") or self.default_key)
default=(self.config.get("api.credentials.access_key", None) or self.default_key)
)
if not self.access_key:
raise ValueError(
@ -105,7 +108,7 @@ class Session(TokenManager):
)
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
default=(self.config.get("api.credentials.secret_key") or self.default_secret)
default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
)
if not self.secret_key:
raise ValueError(
@ -125,7 +128,7 @@ class Session(TokenManager):
self.__worker = worker or gethostname()
self.__max_req_size = self.config.get("api.http.max_req_size")
self.__max_req_size = self.config.get("api.http.max_req_size", None)
if not self.__max_req_size:
raise ValueError("missing max request size")
@ -140,6 +143,11 @@ class Session(TokenManager):
except (jwt.DecodeError, ValueError):
pass
# now set up the session reporting, so that only consecutive retries show a warning
# we do that here, so if we have problems authenticating, we see them immediately
# notice: this suppresses the warnings across the board
urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)
def _send_request(
self,
service,
@ -394,7 +402,65 @@ class Session(TokenManager):
if not config:
from ...config import config_obj
config = config_obj
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
return ENV_HOST.get(default=(config.get("api.api_server", None) or
config.get("api.host", None) or cls.default_host))
@classmethod
def get_app_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
# get from config/environment
web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
if web_host:
return web_host
# return default
host = cls.get_api_server_host(config)
if host == cls.default_host:
return cls.default_web
# compose ourselves
if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.', 1)
if '://api.' in host:
return host.replace('://api.', '://app.', 1)
parsed = urlparse(host)
if parsed.port == 8008:
return host.replace(':8008', ':8080', 1)
raise ValueError('Could not detect TRAINS web application server')
@classmethod
def get_files_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
# get from config/environment
files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
if files_host:
return files_host
# return default
host = cls.get_api_server_host(config)
if host == cls.default_host:
return cls.default_files
# compose ourselves
app_host = cls.get_app_server_host(config)
parsed = urlparse(app_host)
if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
elif parsed.netloc.startswith('demoapp.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
elif parsed.netloc.startswith('app.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
else:
parsed = parsed._replace(netloc=parsed.netloc + ':8081')
return urlunparse(parsed)
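When only the api server is configured, the web and files hosts are derived by a naming convention (demoapi -> demoapp/demofiles, api. -> app./files., port 8008 -> 8080/8081). A self-contained sketch of the web-host half of that mapping, mirroring the logic above rather than calling the library:

from six.moves.urllib.parse import urlparse

def derive_web_host(api_host):
    # mirror of get_app_server_host's fallback path (standalone helper, assumed name)
    if '://demoapi.' in api_host:
        return api_host.replace('://demoapi.', '://demoapp.', 1)
    if '://api.' in api_host:
        return api_host.replace('://api.', '://app.', 1)
    if urlparse(api_host).port == 8008:
        return api_host.replace(':8008', ':8080', 1)
    raise ValueError('Could not detect TRAINS web application server')

print(derive_web_host('http://localhost:8008'))   # -> http://localhost:8080
print(derive_web_host('https://api.myhost.com'))  # -> https://app.myhost.com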
def _do_refresh_token(self, old_token, exp=None):
""" TokenManager abstract method implementation.

View File

@ -27,6 +27,29 @@ def get_config():
return config_obj
def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
class RetryFilter(logging.Filter):
last_instance = None
def __init__(self, total, warning_after=5):
super(RetryFilter, self).__init__()
self.total = total
self.display_warning_after = warning_after
self.last_instance = self
def filter(self, record):
if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
retry_left = self.total - record.args[0].total
return retry_left >= self.display_warning_after
return True
urllib3_log = logging.getLogger('urllib3.connectionpool')
if urllib3_log:
urllib3_log.removeFilter(RetryFilter.last_instance)
urllib3_log.addFilter(RetryFilter(total_retries, display_warning_after))
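A hedged usage sketch of the helper above: with the filter installed, urllib3 retry warnings reach the log only after the configured number of retries has already been consumed, so a single transient failure stays quiet:

import logging

logging.basicConfig(level=logging.WARNING)
# assumed values: warn only once three or more retries have been spent
urllib_log_warning_setup(total_retries=10, display_warning_after=3)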
class TLSv1HTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
self.poolmanager = PoolManager(num_pools=connections,

View File

@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
if log:
log.error(error_msg)
if res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
if raise_on_errors:
raise SendError(res, error_msg)
return res
except requests.exceptions.BaseHTTPError as e:
log.error('failed sending %s: %s' % (str(req), str(e)))
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
except Exception as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
# Infrastructure error
if log:
log.info('retrying request %s' % str(req))
if res and res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
if raise_on_errors:
raise SendError(res, error_msg)
return res
# # Infrastructure error
# if log:
# log.info('retrying request %s' % str(req))
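A condensed sketch of the control flow above (assumed names, not the library API): a response whose result code is at or below 500 is treated as a definitive backend answer and is raised or returned, while exceptions and other codes fall through as infrastructure errors that the caller may retry:

def send_once(do_request, raise_on_errors=True):
    res = None
    try:
        res = do_request()
    except Exception as e:
        print('Failed sending request: %s' % str(e))
    if res is not None and res.meta.result_code <= 500:
        # proper backend error / bad status code - raise or return
        if raise_on_errors and res.meta.result_code != 200:
            raise RuntimeError('backend returned %d' % res.meta.result_code)
        return res
    return None  # infrastructure error, caller may retry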
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,

View File

@ -1,8 +1,8 @@
import collections
import json
import cv2
import six
from threading import Thread, Event
from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._bucket_config = None
self._storage_uri = None
self._async_enable = async_enable
self._flush_frequency = 30.0
self._exit_flag = False
self._flush_event = Event()
self._flush_event.clear()
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
def _set_storage_uri(self, value):
value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
def async_enable(self, value):
self._async_enable = bool(value)
def _daemon(self):
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency)
self._flush_event.clear()
self._write()
# wait for all reports
if self.get_num_results() > 0:
self.wait_for_results()
def _report(self, ev):
self._events.append(ev)
if len(self._events) >= self._flush_threshold:
self._write()
self.flush()
def _write(self):
if not self._events:
@ -88,10 +104,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
"""
Flush cached reports to backend.
"""
self._write()
# wait for all reports
if self.get_num_results() > 0:
self.wait_for_results()
self._flush_event.set()
def stop(self):
self._exit_flag = True
self._flush_event.set()
self._thread.join()
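The reporter now drains its event queue from a background daemon thread instead of only on explicit flushes. A self-contained sketch of that pattern with assumed names (not the Reporter API): the thread wakes up every `frequency` seconds, or immediately when flush() sets the event:

import threading

class PeriodicFlusher(object):
    def __init__(self, frequency=30.0):
        self._frequency = frequency
        self._exit = False
        self._event = threading.Event()
        self._thread = threading.Thread(target=self._daemon)
        self._thread.daemon = True
        self._thread.start()

    def flush(self):
        self._event.set()              # wake the daemon now instead of waiting

    def stop(self):
        self._exit = True
        self._event.set()
        self._thread.join()

    def _daemon(self):
        while not self._exit:
            self._event.wait(self._frequency)   # returns on timeout or flush()
            self._event.clear()
            self._write()

    def _write(self):
        pass                           # placeholder for sending the queued events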
def report_scalar(self, title, series, value, iter):
"""

View File

@ -77,6 +77,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
self._edit_lock = RLock()
super(Task, self).__init__(id=task_id, session=session, log=log)
self._project_name = None
self._storage_uri = None
self._input_model = None
self._output_model = None
@ -87,6 +88,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
)
self._app_server = None
self._files_server = None
if not task_id:
# generate a new task
@ -656,8 +659,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if self.project is None:
return None
if self._project_name and self._project_name[0] == self.project:
return self._project_name[1]
res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
return res.response.project.name
self._project_name = (self.project, res.response.project.name)
return self._project_name[1]
def get_tags(self):
return self._get_task_property("tags")
@ -668,33 +675,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(tags=self.data.tags)
def _get_default_report_storage_uri(self):
app_host = self._get_app_server()
parsed = urlparse(app_host)
if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
elif parsed.netloc.startswith('demoapp.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
elif parsed.netloc.startswith('app.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
else:
parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed)
if not self._files_server:
self._files_server = Session.get_files_server_host()
return self._files_server
@classmethod
def _get_api_server(cls):
return Session.get_api_server_host()
@classmethod
def _get_app_server(cls):
host = cls._get_api_server()
if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.', 1)
if '://api.' in host:
return host.replace('://api.', '://app.', 1)
parsed = urlparse(host)
if parsed.port == 8008:
return host.replace(':8008', ':8080', 1)
def _get_app_server(self):
if not self._app_server:
self._app_server = Session.get_app_server_host()
return self._app_server
def _edit(self, **kwargs):
with self._edit_lock:

View File

@ -1,5 +1,7 @@
import os
import six
from ..config import TASK_LOG_ENVIRONMENT, running_remotely
@ -34,3 +36,43 @@ class EnvironmentBind(object):
if running_remotely():
# put back into os:
os.environ.update(env_param)
class PatchOsFork(object):
_original_fork = None
@classmethod
def patch_fork(cls):
# only once
if cls._original_fork:
return
if six.PY2:
cls._original_fork = staticmethod(os.fork)
else:
cls._original_fork = os.fork
os.fork = cls._patched_fork
@staticmethod
def _patched_fork(*args, **kwargs):
ret = PatchOsFork._original_fork(*args, **kwargs)
# Make sure the new process stdout is logged
if not ret:
from ..task import Task
if Task.current_task() is not None:
# bind sub-process logger
task = Task.init()
task.get_logger().flush()
# if we got here, patch os._exit of our instance to call us
def _at_exit_callback(*args, **kwargs):
# call at exit manually
# noinspection PyProtectedMember
task._at_exit()
# noinspection PyProtectedMember
return os._org_exit(*args, **kwargs)
if not hasattr(os, '_org_exit'):
os._org_exit = os._exit
os._exit = _at_exit_callback
return ret
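A minimal, generic sketch of the same patching idea with assumed names (POSIX only, since os.fork does not exist on Windows): wrap os.fork so that extra setup runs only in the child process, where fork returns 0:

import os

_original_fork = os.fork

def _patched_fork():
    pid = _original_fork()
    if pid == 0:
        # child process: re-initialize per-process resources here (logger, reporter, ...)
        pass
    return pid

os.fork = _patched_fork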

View File

@ -1,4 +1,5 @@
import base64
import os
import sys
import threading
from collections import defaultdict
@ -44,9 +45,17 @@ class EventTrainsWriter(object):
TF SummaryWriter implementation that converts the TensorBoard summary into
Trains events and reports the events (metrics) for a Trains task (logger).
"""
_add_lock = threading.Lock()
_add_lock = threading.RLock()
_series_name_lookup = {}
# store all the created tensorboard writers in the system
# this allows us to ask whether a certain title/series already exists on some EventWriter
# and if it does, then we add to the series name the last token from the logdir
# (so we can differentiate between the two)
# key, value: key=hash(title, graph), value=EventTrainsWriter._id
_title_series_writers_lookup = {}
_event_writers_id_to_logdir = {}
@property
def variants(self):
return self._variants
@ -54,8 +63,8 @@ class EventTrainsWriter(object):
def prepare_report(self):
return self.variants.copy()
@staticmethod
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
logdir_header='series'):
"""
Split a tf.summary tag line to variant and metric.
Variant is the first part of the split tag, metric is the second.
@ -64,15 +73,64 @@ class EventTrainsWriter(object):
:param str split_char: a character to split the tag on
:param str join_char: a character to join the the splits
:param str default_title: variant to use in case no variant can be inferred automatically
:param str logdir_header: if 'series_last' then series=header: series, if 'series' then series=series :header,
if 'title_last' then title=header title, if 'title' then title=title header
:return: (str, str) variant and metric
"""
splitted_tag = tag.split(split_char)
series = join_char.join(splitted_tag[-num_split_parts:])
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
# check if we already decided that we need to change the title/series
graph_id = hash((title, series))
if graph_id in self._graph_name_lookup:
return self._graph_name_lookup[graph_id]
# check if someone other than us used this combination
with self._add_lock:
event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
if not event_writer_id:
# put us there
self._title_series_writers_lookup[graph_id] = self._id
elif event_writer_id != self._id:
# if there is someone else, change our series name and store us
org_series = series
org_title = title
other_logdir = self._event_writers_id_to_logdir[event_writer_id]
split_logddir = self._logdir.split(os.path.sep)
unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep))
header = '/'.join(s for s in split_logddir if s in unique_logdir)
if logdir_header == 'series_last':
series = header + ': ' + series
elif logdir_header == 'series':
series = series + ' :' + header
elif logdir_header == 'title':
title = title + ' ' + header
else: # logdir_header == 'title_last':
title = header + ' ' + title
graph_id = hash((title, series))
# check if for some reason the new series is already occupied
new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
if new_event_writer_id is not None and new_event_writer_id != self._id:
# well that's about it, nothing else we could do
if logdir_header == 'series_last':
series = str(self._logdir) + ': ' + org_series
elif logdir_header == 'series':
series = org_series + ' :' + str(self._logdir)
elif logdir_header == 'title':
title = org_title + ' ' + str(self._logdir)
else: # logdir_header == 'title_last':
title = str(self._logdir) + ' ' + org_title
graph_id = hash((title, series))
self._title_series_writers_lookup[graph_id] = self._id
# store for next time
self._graph_name_lookup[graph_id] = (title, series)
return title, series
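A worked example of the basic split performed before the logdir disambiguation kicks in, as a standalone helper rather than the class method:

def split_tag(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
    parts = tag.split(split_char)
    series = join_char.join(parts[-num_split_parts:])
    title = join_char.join(parts[:-num_split_parts]) or default_title
    return title, series

print(split_tag('train/loss', num_split_parts=1))          # -> ('train', 'loss')
print(split_tag('accuracy', num_split_parts=1))            # -> ('variant', 'accuracy')
print(split_tag('val/images/epoch_3', num_split_parts=3))  # -> ('variant', 'val_images_epoch_3')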
def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
histogram_granularity=50, max_keep_images=None):
def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
"""
Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
@ -87,6 +145,9 @@ class EventTrainsWriter(object):
"""
# We are the events_writer, so that's what we'll pass
IsTensorboardInit.set_tensorboard_used()
self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
self._id = hash(self._logdir)
self._event_writers_id_to_logdir[self._id] = self._logdir
self.max_keep_images = max_keep_images
self.report_freq = report_freq
self.image_report_freq = image_report_freq if image_report_freq else report_freq
@ -99,6 +160,7 @@ class EventTrainsWriter(object):
self._hist_report_cache = {}
self._hist_x_granularity = 50
self._max_step = 0
self._graph_name_lookup = {}
def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException
@ -131,7 +193,7 @@ class EventTrainsWriter(object):
if img_data_np is None:
return
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
if img_data_np.dtype != np.uint8:
# assume scale 0-1
img_data_np = (img_data_np * 255).astype(np.uint8)
@ -168,7 +230,7 @@ class EventTrainsWriter(object):
return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
def _add_scalar(self, tag, step, scalar_data):
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')
# update scalar cache
num, value = self._scalar_report_cache.get((title, series), (0, 0))
@ -216,7 +278,8 @@ class EventTrainsWriter(object):
# Y-axis (rows) is iteration (from 0 to current Step)
# X-axis averaged bins (conformed sample 'bucketLimit')
# Z-axis actual value (interpolated 'bucket')
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
logdir_header='series')
# get histograms from cache
hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.trains:
try:
logdir = self.get_logdir()
except Exception:
logdir = None
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.trains:
try:
logdir = self.get_logdir()
except Exception:
logdir = None
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
**PatchSummaryToEventTransformer.defaults_dict)
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
self.trains.add_event(*args, **kwargs)
@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
# patch the events writer field, and add a double Event Logger (Trains and original)
base_eventwriter = __dict__['event_writer']
try:
logdir = base_eventwriter.get_logdir()
except Exception:
logdir = None
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
logdir=logdir, **defaults_dict)
# order is important, the return value of ProxyEventsWriter is the last object in the list
__dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
getLogger(TrainsFrameworkAdapter).warning(str(ex))
@staticmethod
def _get_event_writer():
def _get_event_writer(writer):
if not PatchTensorFlowEager.__main_task:
return None
if PatchTensorFlowEager.__trains_event_writer is None:
try:
logdir = writer.get_logdir()
except Exception:
logdir = None
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
**PatchTensorFlowEager.defaults_dict)
return PatchTensorFlowEager.__trains_event_writer
@staticmethod
@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
try:
event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
try:
event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
event_writer = PatchTensorFlowEager._get_event_writer()
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer:
try:
event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@ -1350,93 +1431,3 @@ class PatchTensorflowModelIO(object):
pass
return model
class PatchPyTorchModelIO(object):
__main_task = None
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
PatchPyTorchModelIO.__main_task = task
PatchPyTorchModelIO._patch_model_io()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@staticmethod
def _patch_model_io():
if PatchPyTorchModelIO.__patched:
return
if 'torch' not in sys.modules:
return
PatchPyTorchModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure torch.__init__ is called
import torch
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
except ImportError:
pass
except Exception:
pass # print('Failed patching pytorch')
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchPyTorchModelIO.__main_task:
return ret
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
filename = None
# give the model a descriptive name based on the file name
# noinspection PyBroadException
try:
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchPyTorchModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
if running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(filename or f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model

View File

@ -11,19 +11,20 @@ from trains.config import config_obj
description = """
Please create new credentials using the web app: {}/admin
Please create new credentials using the web app: {}/profile
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
Paste credentials here: """
try:
def_host = ENV_HOST.get(default=config_obj.get("api.host"))
def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080'
except Exception:
def_host = 'http://localhost:8080'
host_description = """
Editing configuration file: {CONFIG_FILE}
Enter the url of the trains-server's api service, for example: http://localhost:8008 : """.format(
Enter the url of the trains-server's Web service, for example: {HOST}
""".format(
CONFIG_FILE=LOCAL_CONFIG_FILES[0],
HOST=def_host,
)
@ -37,64 +38,60 @@ def main():
print('Leaving setup, feel free to edit the configuration file.')
return
print(host_description, end='')
parsed_host = None
while not parsed_host:
parse_input = input()
if not parse_input:
parse_input = def_host
# noinspection PyBroadException
try:
if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
parse_input = 'http://'+parse_input
parsed_host = urlparse(parse_input)
if parsed_host.scheme not in ('http', 'https'):
parsed_host = None
except Exception:
parsed_host = None
print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
print(host_description)
web_host = input_url('Web Application Host', '')
parsed_host = verify_url(web_host)
if parsed_host.port == 8080:
# this is a docker setup; 8080 is the web address, we need the api address, which is 8008
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
if parsed_host.port == 8008:
print('Port 8008 is the api port. Replacing 8008 with 8080 for the Web application')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
elif parsed_host.port == 8080:
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'):
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
parsed_host.netloc))
# this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('app.'):
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
parsed_host.netloc))
# this is our application server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.port == 8008:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapi.'):
print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
parsed_host.netloc))
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('api.'):
print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
parsed_host.netloc))
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
else:
api_host = None
web_host = None
api_host = ''
web_host = ''
files_host = ''
if not parsed_host.port:
print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
replace_port = input().lower()
if not replace_port or replace_port == 'y' or replace_port == 'yes':
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
if not api_host:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
if not web_host:
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
print('Host configured to: {}'.format(api_host))
api_host = input_url('API Host', api_host)
files_host = input_url('File Store Host', files_host)
print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
api_host, web_host, files_host))
print(description.format(web_host), end='')
parse_input = input()
@ -133,11 +130,14 @@ def main():
header = '# TRAINS SDK configuration file\n' \
'api {\n' \
' # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
' host: %s\n' \
' # Credentials are generated in the webapp, %s/admin\n' \
' api_server: %s\n' \
' web_server: %s\n' \
' files_server: %s\n' \
' # Credentials are generated in the webapp, %s/profile\n' \
' credentials {"access_key": "%s", "secret_key": "%s"}\n' \
'}\n' \
'sdk ' % (api_host, web_host, credentials['access_key'], credentials['secret_key'])
'sdk ' % (api_host, web_host, files_host,
web_host, credentials['access_key'], credentials['secret_key'])
f.write(header)
f.write(default_sdk)
except Exception:
@ -148,5 +148,30 @@ def main():
print('TRAINS setup completed successfully.')
def input_url(host_type, host=None):
while True:
print('{} configured to: [{}] '.format(host_type, host), end='')
parse_input = input()
if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
break
if parse_input and verify_url(parse_input):
host = parse_input
break
return host
def verify_url(parse_input):
try:
if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
parse_input = 'http://' + parse_input
parsed_host = urlparse(parse_input)
if parsed_host.scheme not in ('http', 'https'):
parsed_host = None
except Exception:
parsed_host = None
print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
return parsed_host
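A hedged example of verify_url: a bare host is normalized with an http:// scheme before parsing, so the wizard can accept 'host:port' style input (the value is an assumption):

parsed = verify_url('localhost:8080')
print(parsed.scheme, parsed.netloc)   # -> http localhost:8080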
if __name__ == '__main__':
main()

View File

@ -81,11 +81,14 @@ class Logger(object):
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
# noinspection PyBroadException
try:
Logger._stdout_original_write = sys.stdout.write
if Logger._stdout_original_write is None:
Logger._stdout_original_write = sys.stdout.write
# this will only work in python 3, guard it with try/except
sys.stdout._original_write = sys.stdout.write
if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = stdout__patched__write__
sys.stderr._original_write = sys.stderr.write
if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write
sys.stderr.write = stderr__patched__write__
except Exception:
pass
@ -113,6 +116,7 @@ class Logger(object):
msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
@ -128,6 +132,7 @@ class Logger(object):
if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout:
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
Logger._stdout_original_write(str(msg)+'\n')
@ -637,11 +642,13 @@ class Logger(object):
@classmethod
def _remove_std_logger(self):
if isinstance(sys.stdout, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stdout.connect(None)
except Exception:
pass
if isinstance(sys.stderr, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stderr.connect(None)
except Exception:
@ -711,7 +718,13 @@ class PrintPatchLogger(object):
if cur_line:
with PrintPatchLogger.recursion_protect_lock:
self._log.console(cur_line, level=self._log_level, omit_console=True)
# noinspection PyBroadException
try:
if self._log:
self._log.console(cur_line, level=self._log_level, omit_console=True)
except Exception:
# what can we do, nothing
pass
else:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
@ -719,8 +732,7 @@ class PrintPatchLogger(object):
self._terminal.write(message)
def connect(self, logger):
if self._log:
self._log._flush_stdout_handler()
self._cur_line = ''
self._log = logger
def __getattr__(self, attr):

View File

@ -26,7 +26,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.environ_bind import EnvironmentBind
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
@ -66,6 +66,7 @@ class Task(_Task):
__create_protection = object()
__main_task = None
__exit_hook = None
__forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@ -104,7 +105,6 @@ class Task(_Task):
self._resource_monitor = None
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
self.__register_at_exit(self._at_exit)
@classmethod
def current_task(cls):
@ -132,9 +132,10 @@ class Task(_Task):
:param project_name: project to create the task in (if project doesn't exist, it will be created)
:param task_name: task name to be created (in development mode, not when running remotely)
:param task_type: task type to be created (in development mode, not when running remotely)
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
if False every time we call the function we create a new task with the same name \
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
if False, every call to the function creates a new task with the same name
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
Note: A closed or published task will not be reused, and a new task will be created.
:param output_uri: Default location for output models (currently supports folder/S3/GS).
notice: a sub-folder (task_id) is created in the destination folder for all outputs.
@ -166,12 +167,31 @@ class Task(_Task):
)
if cls.__main_task is not None:
# if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread
cls.__main_task._detect_repo_async_thread = None
# remove the logger from the previous process
logger = cls.__main_task.get_logger()
logger.set_flush_period(None)
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task._reporter = None
cls.__main_task.get_logger()
# unregister signal hooks, they cause subprocess to hang
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
if not running_remotely():
verify_defaults_match()
return cls.__main_task
# check that we are not a child process, in that case do nothing
# check that we are not a child process; in that case, do nothing.
# we should not get here unless this is the Windows platform; all other platforms support fork
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
class _TaskStub(object):
def __call__(self, *args, **kwargs):
@ -212,9 +232,10 @@ class Task(_Task):
raise
else:
Task.__main_task = task
# Patch argparse to be aware of the current task
argparser_update_currenttask(Task.__main_task)
EnvironmentBind.update_current_task(Task.__main_task)
# register the main task for at exit hooks (there should only be one)
task.__register_at_exit(task._at_exit)
# patch OS forking
PatchOsFork.patch_fork()
if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task)
@ -227,21 +248,19 @@ class Task(_Task):
if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start()
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed())
if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task)
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser, parsed_args=parsed_args)
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
@ -339,7 +358,9 @@ class Task(_Task):
in_dev_mode = not running_remotely()
if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
closed_old_task = cls.__close_timed_out_task(default_task)
else:
@ -600,6 +621,9 @@ class Task(_Task):
"""
self._at_exit()
self._at_exit_called = False
# unregister atexit callbacks and signal hooks, if we are the main task
if self.is_main_task():
self.__register_at_exit(None)
def is_current_task(self):
"""
@ -914,9 +938,12 @@ class Task(_Task):
Will happen automatically once we exit code, i.e. atexit
:return:
"""
# protect sub-process at_exit
if self._at_exit_called:
return
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
# noinspection PyBroadException
try:
# from here do not get into watch dog
@ -948,28 +975,32 @@ class Task(_Task):
# from here, do not send log in background thread
if wait_for_uploads:
self.flush(wait_for_uploads=True)
# wait until the reporter flushes everything
self.reporter.stop()
if print_done_waiting:
self.log.info('Finished uploading')
else:
self._logger._flush_stdout_handler()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
if not is_sub_process:
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# stop resource monitoring
if self._resource_monitor:
self._resource_monitor.stop()
self._logger.set_flush_period(None)
# this is so in theory we can close a main task and start a new one
Task.__main_task = None
@ -978,7 +1009,7 @@ class Task(_Task):
pass
@classmethod
def __register_at_exit(cls, exit_callback):
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
class ExitHooks(object):
_orig_exit = None
_orig_exc_handler = None
@ -1000,7 +1031,21 @@ class Task(_Task):
except Exception:
pass
self._exit_callback = callback
atexit.register(self._exit_callback)
if callback:
self.hook()
else:
# unregister the interrupt hook
print('removing int hook', self._orig_exc_handler)
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for s in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, self._org_handlers[s])
except Exception:
pass
self._org_handlers = {}
def hook(self):
if self._orig_exit is None:
@ -1009,20 +1054,23 @@ class Task(_Task):
if self._orig_exc_handler is None:
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
atexit.register(self._exit_callback)
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
if self._exit_callback:
atexit.register(self._exit_callback)
if self._org_handlers:
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
def exit(self, code=0):
self.exit_code = code
@ -1077,6 +1125,22 @@ class Task(_Task):
# return handler result
return org_handler
# we only remove the signals since this will hang subprocesses
if only_remove_signal_and_exception_hooks:
if not cls.__exit_hook:
return
if cls.__exit_hook._orig_exc_handler:
sys.excepthook = cls.__exit_hook._orig_exc_handler
cls.__exit_hook._orig_exc_handler = None
for s in cls.__exit_hook._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, cls.__exit_hook._org_handlers[s])
except Exception:
pass
cls.__exit_hook._org_handlers = {}
return
if cls.__exit_hook is None:
# noinspection PyBroadException
try:
@ -1084,13 +1148,13 @@ class Task(_Task):
cls.__exit_hook.hook()
except Exception:
cls.__exit_hook = None
elif cls.__main_task is None:
else:
cls.__exit_hook.update_callback(exit_callback)
@classmethod
def __get_task(cls, task_id=None, project_name=None, task_name=None):
if task_id:
return cls(private=cls.__create_protection, task_id=task_id)
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
res = cls._send(
cls._get_default_session(),

View File

@ -1,3 +1,4 @@
import os
import time
from threading import Lock
@ -6,7 +7,8 @@ import six
class AsyncManagerMixin(object):
_async_results_lock = Lock()
_async_results = []
# per-pid (process) list of async jobs (supports sub-process forking)
_async_results = {}
@classmethod
def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
@ -14,8 +16,9 @@ class AsyncManagerMixin(object):
try:
cls._async_results_lock.acquire()
# discard completed results
cls._async_results = [r for r in cls._async_results if not r.ready()]
num_results = len(cls._async_results)
pid = os.getpid()
cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
num_results = len(cls._async_results[pid])
if wait_on_max_results is not None and num_results >= wait_on_max_results:
# At least max_results results are still pending, wait
if wait_cb:
@ -25,7 +28,9 @@ class AsyncManagerMixin(object):
continue
# add result
if result and not result.ready():
cls._async_results.append(result)
if not cls._async_results.get(pid):
cls._async_results[pid] = []
cls._async_results[pid].append(result)
break
finally:
cls._async_results_lock.release()
@ -34,7 +39,8 @@ class AsyncManagerMixin(object):
def wait_for_results(cls, timeout=None, max_num_uploads=None):
remaining = timeout
count = 0
for r in cls._async_results:
pid = os.getpid()
for r in cls._async_results.get(pid, []):
if r.ready():
continue
t = time.time()
@ -48,13 +54,14 @@ class AsyncManagerMixin(object):
if max_num_uploads is not None and max_num_uploads - count <= 0:
break
if timeout is not None:
remaining = max(0, remaining - max(0, time.time() - t))
remaining = max(0., remaining - max(0., time.time() - t))
if not remaining:
break
@classmethod
def get_num_results(cls):
if cls._async_results is not None:
return len([r for r in cls._async_results if not r.ready()])
pid = os.getpid()
if cls._async_results.get(pid, []):
return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
else:
return 0
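A minimal sketch of the per-process bookkeeping idea with assumed names: keying the pending results by os.getpid() means a forked child starts from an empty list instead of inheriting, and waiting on, the parent's in-flight uploads:

import os

_pending = {}  # pid -> list of in-flight results

def add_result(result):
    _pending.setdefault(os.getpid(), []).append(result)

def pending_for_this_process():
    # result objects are assumed to expose ready(), as in the mixin above
    return [r for r in _pending.get(os.getpid(), []) if not r.ready()]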

View File

@ -1 +1 @@
__version__ = '0.10.2'
__version__ = '0.10.3rc1'