From 51cc50e2391144c66c647c0d0075a6055ddd8f39 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sat, 20 Jul 2019 22:01:27 +0300
Subject: [PATCH 1/7] Add separate api/web/file server configuration (backward support included). OS environment override with: TRAINS_API_HOST / TRAINS_WEB_HOST / TRAINS_FILES_HOST

---
 docs/trains.conf                           | 10 ++-
 trains/backend_api/config/default/api.conf |  8 +-
 trains/backend_api/session/defs.py         |  2 +
 trains/backend_api/session/session.py      | 82 ++++++++++++++++--
 trains/backend_api/utils.py                | 23 +++++
 trains/config/default/__main__.py          | 99 ++++++++++++++--------
 6 files changed, 175 insertions(+), 49 deletions(-)

diff --git a/docs/trains.conf b/docs/trains.conf
index a27a3948..7e12205e 100644
--- a/docs/trains.conf
+++ b/docs/trains.conf
@@ -1,7 +1,13 @@
 # TRAINS SDK configuration file
 api {
-    # Notice: 'host' is the api server (default port 8008), not the web server.
-    host: http://localhost:8008
+    # web_server on port 8080
+    web_server: "http://localhost:8080"
+
+    # Notice: 'api_server' is the api server (default port 8008), not the web server.
+    api_server: "http://localhost:8008"
+
+    # file server on port 8081
+    files_server: "http://localhost:8081"
     # Credentials are generated in the webapp, http://localhost:8080/admin
     credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
diff --git a/trains/backend_api/config/default/api.conf b/trains/backend_api/config/default/api.conf
index cbac9189..de96becc 100644
--- a/trains/backend_api/config/default/api.conf
+++ b/trains/backend_api/config/default/api.conf
@@ -1,7 +1,11 @@
 {
     version: 1.5
-    # default https://demoapi.trainsai.io host
-    host: ""
+    # default api_server: https://demoapi.trainsai.io
+    api_server: ""
+    # default web_server: https://demoapp.trainsai.io
+    web_server: ""
+    # default files_server: https://demofiles.trainsai.io
+    files_server: ""
     # verify host ssl certificate, set to False only if you have a very good reason
     verify_certificate: True
diff --git a/trains/backend_api/session/defs.py b/trains/backend_api/session/defs.py
index 73fc6603..5ea6a97d 100644
--- a/trains/backend_api/session/defs.py
+++ b/trains/backend_api/session/defs.py
@@ -2,6 +2,8 @@ from ...backend_config import EnvEntry
 ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
+ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "ALG_WEB_HOST")
+ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "ALG_FILES_HOST")
 ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py
index da71098e..bdb71aa0 100644
--- a/trains/backend_api/session/session.py
+++ b/trains/backend_api/session/session.py
@@ -2,6 +2,7 @@ import json as json_lib
 import sys
 import types
 from socket import gethostname
+from six.moves.urllib.parse import urlparse, urlunparse
 import jwt
 import requests
@@ -10,11 +11,11 @@ from pyhocon import ConfigTree
 from requests.auth import HTTPBasicAuth
 from .callresult import CallResult
-from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
+from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
 from .request import Request, BatchRequest
 from .token_manager import TokenManager
 from ..config import load
-from ..utils import get_http_session_with_retry
+from ..utils import get_http_session_with_retry, urllib_log_warning_setup from ..version import __version__ @@ -32,11 +33,13 @@ class Session(TokenManager): _async_status_code = 202 _session_requests = 0 - _session_initial_timeout = (1.0, 10) - _session_timeout = (5.0, None) + _session_initial_timeout = (3.0, 10.) + _session_timeout = (5.0, 300.) api_version = '2.1' default_host = "https://demoapi.trainsai.io" + default_web = "https://demoapp.trainsai.io" + default_files = "https://demofiles.trainsai.io" default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" @@ -97,7 +100,7 @@ class Session(TokenManager): self._logger = logger self.__access_key = api_key or ENV_ACCESS_KEY.get( - default=(self.config.get("api.credentials.access_key") or self.default_key) + default=(self.config.get("api.credentials.access_key", None) or self.default_key) ) if not self.access_key: raise ValueError( @@ -105,7 +108,7 @@ class Session(TokenManager): ) self.__secret_key = secret_key or ENV_SECRET_KEY.get( - default=(self.config.get("api.credentials.secret_key") or self.default_secret) + default=(self.config.get("api.credentials.secret_key", None) or self.default_secret) ) if not self.secret_key: raise ValueError( @@ -125,7 +128,7 @@ class Session(TokenManager): self.__worker = worker or gethostname() - self.__max_req_size = self.config.get("api.http.max_req_size") + self.__max_req_size = self.config.get("api.http.max_req_size", None) if not self.__max_req_size: raise ValueError("missing max request size") @@ -140,6 +143,11 @@ class Session(TokenManager): except (jwt.DecodeError, ValueError): pass + # now setup the session reporting, so one consecutive retries will show warning + # we do that here, so if we have problems authenticating, we see them immediately + # notice: this is across the board warning omission + urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3) + def _send_request( self, service, @@ -394,7 +402,65 @@ class Session(TokenManager): if not config: from ...config import config_obj config = config_obj - return ENV_HOST.get(default=(config.get("api.host") or cls.default_host)) + return ENV_HOST.get(default=(config.get("api.api_server", None) or + config.get("api.host", None) or cls.default_host)) + + @classmethod + def get_app_server_host(cls, config=None): + if not config: + from ...config import config_obj + config = config_obj + + # get from config/environment + web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None)) + if web_host: + return web_host + + # return default + host = cls.get_api_server_host(config) + if host == cls.default_host: + return cls.default_web + + # compose ourselves + if '://demoapi.' in host: + return host.replace('://demoapi.', '://demoapp.', 1) + if '://api.' 
in host: + return host.replace('://api.', '://app.', 1) + + parsed = urlparse(host) + if parsed.port == 8008: + return host.replace(':8008', ':8080', 1) + + raise ValueError('Could not detect TRAINS web application server') + + @classmethod + def get_files_server_host(cls, config=None): + if not config: + from ...config import config_obj + config = config_obj + # get from config/environment + files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None))) + if files_host: + return files_host + + # return default + host = cls.get_api_server_host(config) + if host == cls.default_host: + return cls.default_files + + # compose ourselves + app_host = cls.get_app_server_host(config) + parsed = urlparse(app_host) + if parsed.port: + parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1)) + elif parsed.netloc.startswith('demoapp.'): + parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1)) + elif parsed.netloc.startswith('app.'): + parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1)) + else: + parsed = parsed._replace(netloc=parsed.netloc + ':8081') + + return urlunparse(parsed) def _do_refresh_token(self, old_token, exp=None): """ TokenManager abstract method implementation. diff --git a/trains/backend_api/utils.py b/trains/backend_api/utils.py index da25e129..7bf47c07 100644 --- a/trains/backend_api/utils.py +++ b/trains/backend_api/utils.py @@ -27,6 +27,29 @@ def get_config(): return config_obj +def urllib_log_warning_setup(total_retries=10, display_warning_after=5): + class RetryFilter(logging.Filter): + last_instance = None + + def __init__(self, total, warning_after=5): + super(RetryFilter, self).__init__() + self.total = total + self.display_warning_after = warning_after + self.last_instance = self + + def filter(self, record): + if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry): + retry_left = self.total - record.args[0].total + return retry_left >= self.display_warning_after + + return True + + urllib3_log = logging.getLogger('urllib3.connectionpool') + if urllib3_log: + urllib3_log.removeFilter(RetryFilter.last_instance) + urllib3_log.addFilter(RetryFilter(total_retries, display_warning_after)) + + class TLSv1HTTPAdapter(HTTPAdapter): def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): self.poolmanager = PoolManager(num_pools=connections, diff --git a/trains/config/default/__main__.py b/trains/config/default/__main__.py index aa450008..4cdd52c0 100644 --- a/trains/config/default/__main__.py +++ b/trains/config/default/__main__.py @@ -11,19 +11,20 @@ from trains.config import config_obj description = """ -Please create new credentials using the web app: {}/admin +Please create new credentials using the web app: {}/profile In the Admin page, press "Create new credentials", then press "Copy to clipboard" Paste credentials here: """ try: - def_host = ENV_HOST.get(default=config_obj.get("api.host")) + def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080' except Exception: def_host = 'http://localhost:8080' host_description = """ Editing configuration file: {CONFIG_FILE} -Enter the url of the trains-server's api service, for example: http://localhost:8008 : """.format( +Enter the url of the trains-server's Web service, for example: {HOST} +""".format( CONFIG_FILE=LOCAL_CONFIG_FILES[0], HOST=def_host, ) @@ -37,64 +38,60 @@ def main(): print('Leaving setup, feel free to edit the configuration file.') return 
- print(host_description, end='') - parsed_host = None - while not parsed_host: - parse_input = input() - if not parse_input: - parse_input = def_host - # noinspection PyBroadException - try: - if not parse_input.startswith('http://') and not parse_input.startswith('https://'): - parse_input = 'http://'+parse_input - parsed_host = urlparse(parse_input) - if parsed_host.scheme not in ('http', 'https'): - parsed_host = None - except Exception: - parsed_host = None - print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='') + print(host_description) + web_host = input_url('Web Application Host', '') + parsed_host = verify_url(web_host) - if parsed_host.port == 8080: - # this is a docker 8080 is the web address, we need the api address, it is 8008 - print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008') + if parsed_host.port == 8008: + print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application') + api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path + elif parsed_host.port == 8080: api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path elif parsed_host.netloc.startswith('demoapp.'): - print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format( - parsed_host.netloc)) # this is our demo server api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path elif parsed_host.netloc.startswith('app.'): - print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format( - parsed_host.netloc)) # this is our application server api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - elif parsed_host.port == 8008: - api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path elif parsed_host.netloc.startswith('demoapi.'): + print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format( + parsed_host.netloc)) api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path elif parsed_host.netloc.startswith('api.'): + print('{} is the api server, we need the web server. 
Replacing \'api.\' with \'app.\''.format( + parsed_host.netloc)) api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path else: - api_host = None - web_host = None + api_host = '' + web_host = '' + files_host = '' if not parsed_host.port: print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='') replace_port = input().lower() if not replace_port or replace_port == 'y' or replace_port == 'yes': api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path + files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path if not api_host: api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - if not web_host: - web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - print('Host configured to: {}'.format(api_host)) + api_host = input_url('API Host', api_host) + files_host = input_url('File Store Host', files_host) + + print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format( + api_host, web_host, files_host)) print(description.format(web_host), end='') parse_input = input() @@ -133,11 +130,14 @@ def main(): header = '# TRAINS SDK configuration file\n' \ 'api {\n' \ ' # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \ - ' host: %s\n' \ - ' # Credentials are generated in the webapp, %s/admin\n' \ + ' api_server: %s\n' \ + ' web_server: %s\n' \ + ' files_server: %s\n' \ + ' # Credentials are generated in the webapp, %s/profile\n' \ ' credentials {"access_key": "%s", "secret_key": "%s"}\n' \ '}\n' \ - 'sdk ' % (api_host, web_host, credentials['access_key'], credentials['secret_key']) + 'sdk ' % (api_host, web_host, files_host, + web_host, credentials['access_key'], credentials['secret_key']) f.write(header) f.write(default_sdk) except Exception: @@ -148,5 +148,30 @@ def main(): print('TRAINS setup completed successfully.') +def input_url(host_type, host=None): + while True: + print('{} configured to: [{}] '.format(host_type, host), end='') + parse_input = input() + if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'): + break + if parse_input and verify_url(parse_input): + host = parse_input + break + return host + + +def verify_url(parse_input): + try: + if not parse_input.startswith('http://') and not parse_input.startswith('https://'): + parse_input = 'http://' + parse_input + parsed_host = urlparse(parse_input) + if parsed_host.scheme not in ('http', 'https'): + parsed_host = None + except Exception: + parsed_host = None + print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='') + return parsed_host + + if __name__ == '__main__': main() From 00f873081a2d4b83c64f45a1c8091227da8283a7 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 22:02:19 +0300 Subject: [PATCH 2/7] Add support for multiple event writers in the same session --- trains/binding/frameworks/tensorflow_bind.py | 203 +++++++++---------- 1 file changed, 97 insertions(+), 106 deletions(-) diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py index 1f8697cd..55c6fc1b 100644 --- 
a/trains/binding/frameworks/tensorflow_bind.py +++ b/trains/binding/frameworks/tensorflow_bind.py @@ -1,4 +1,5 @@ import base64 +import os import sys import threading from collections import defaultdict @@ -44,9 +45,17 @@ class EventTrainsWriter(object): TF SummaryWriter implementation that converts the tensorboard's summary into Trains events and reports the events (metrics) for an Trains task (logger). """ - _add_lock = threading.Lock() + _add_lock = threading.RLock() _series_name_lookup = {} + # store all the created tensorboard writers in the system + # this allows us to as weather a certain tile/series already exist on some EventWriter + # and if it does, then we add to the series name the last token from the logdir + # (so we can differentiate between the two) + # key, value: key=hash(title, graph), value=EventTrainsWriter._id + _title_series_writers_lookup = {} + _event_writers_id_to_logdir = {} + @property def variants(self): return self._variants @@ -54,8 +63,8 @@ class EventTrainsWriter(object): def prepare_report(self): return self.variants.copy() - @staticmethod - def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'): + def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant', + logdir_header='series'): """ Split a tf.summary tag line to variant and metric. Variant is the first part of the split tag, metric is the second. @@ -64,15 +73,64 @@ class EventTrainsWriter(object): :param str split_char: a character to split the tag on :param str join_char: a character to join the the splits :param str default_title: variant to use in case no variant can be inferred automatically + :param str logdir_header: if 'series_last' then series=header: series, if 'series then series=series :header, + if 'title_last' then title=header title, if 'title' then title=title header :return: (str, str) variant and metric """ splitted_tag = tag.split(split_char) series = join_char.join(splitted_tag[-num_split_parts:]) title = join_char.join(splitted_tag[:-num_split_parts]) or default_title + + # check if we already decided that we need to change the title/series + graph_id = hash((title, series)) + if graph_id in self._graph_name_lookup: + return self._graph_name_lookup[graph_id] + + # check if someone other than us used this combination + with self._add_lock: + event_writer_id = self._title_series_writers_lookup.get(graph_id, None) + if not event_writer_id: + # put us there + self._title_series_writers_lookup[graph_id] = self._id + elif event_writer_id != self._id: + # if there is someone else, change our series name and store us + org_series = series + org_title = title + other_logdir = self._event_writers_id_to_logdir[event_writer_id] + split_logddir = self._logdir.split(os.path.sep) + unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep)) + header = '/'.join(s for s in split_logddir if s in unique_logdir) + if logdir_header == 'series_last': + series = header + ': ' + series + elif logdir_header == 'series': + series = series + ' :' + header + elif logdir_header == 'title': + title = title + ' ' + header + else: # logdir_header == 'title_last': + title = header + ' ' + title + graph_id = hash((title, series)) + # check if for some reason the new series is already occupied + new_event_writer_id = self._title_series_writers_lookup.get(graph_id) + if new_event_writer_id is not None and new_event_writer_id != self._id: + # well that's about it, nothing else we could do + if logdir_header == 
'series_last': + series = str(self._logdir) + ': ' + org_series + elif logdir_header == 'series': + series = org_series + ' :' + str(self._logdir) + elif logdir_header == 'title': + title = org_title + ' ' + str(self._logdir) + else: # logdir_header == 'title_last': + title = str(self._logdir) + ' ' + org_title + graph_id = hash((title, series)) + + self._title_series_writers_lookup[graph_id] = self._id + + # store for next time + self._graph_name_lookup[graph_id] = (title, series) return title, series - def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10, - histogram_granularity=50, max_keep_images=None): + def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None, + histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None): """ Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter @@ -87,6 +145,9 @@ class EventTrainsWriter(object): """ # We are the events_writer, so that's what we'll pass IsTensorboardInit.set_tensorboard_used() + self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir)) + self._id = hash(self._logdir) + self._event_writers_id_to_logdir[self._id] = self._logdir self.max_keep_images = max_keep_images self.report_freq = report_freq self.image_report_freq = image_report_freq if image_report_freq else report_freq @@ -99,6 +160,7 @@ class EventTrainsWriter(object): self._hist_report_cache = {} self._hist_x_granularity = 50 self._max_step = 0 + self._graph_name_lookup = {} def _decode_image(self, img_str, width, height, color_channels): # noinspection PyBroadException @@ -131,7 +193,7 @@ class EventTrainsWriter(object): if img_data_np is None: return - title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images') + title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title') if img_data_np.dtype != np.uint8: # assume scale 0-1 img_data_np = (img_data_np * 255).astype(np.uint8) @@ -168,7 +230,7 @@ class EventTrainsWriter(object): return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix) def _add_scalar(self, tag, step, scalar_data): - title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars') + title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last') # update scalar cache num, value = self._scalar_report_cache.get((title, series), (0, 0)) @@ -216,7 +278,8 @@ class EventTrainsWriter(object): # Y-axis (rows) is iteration (from 0 to current Step) # X-axis averaged bins (conformed sample 'bucketLimit') # Z-axis actual value (interpolated 'bucket') - title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms') + title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms', + logdir_header='series') # get histograms from cache hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None)) @@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object): if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task: return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs) if not self.trains: + try: + logdir = self.get_logdir() + except Exception: + logdir = None self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - 
**PatchSummaryToEventTransformer.defaults_dict) + logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) @@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object): if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task: return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs) if not self.trains: + try: + logdir = self.get_logdir() + except Exception: + logdir = None self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), - **PatchSummaryToEventTransformer.defaults_dict) + logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) # noinspection PyBroadException try: self.trains.add_event(*args, **kwargs) @@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object): # patch the events writer field, and add a double Event Logger (Trains and original) base_eventwriter = __dict__['event_writer'] + try: + logdir = base_eventwriter.get_logdir() + except Exception: + logdir = None defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict - trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict) + trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), + logdir=logdir, **defaults_dict) # order is important, the return value of ProxyEventsWriter is the last object in the list __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter]) @@ -798,12 +874,17 @@ class PatchTensorFlowEager(object): getLogger(TrainsFrameworkAdapter).debug(str(ex)) @staticmethod - def _get_event_writer(): + def _get_event_writer(writer): if not PatchTensorFlowEager.__main_task: return None if PatchTensorFlowEager.__trains_event_writer is None: + try: + logdir = writer.get_logdir() + except Exception: + logdir = None PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter( - logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict) + logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir, + **PatchTensorFlowEager.defaults_dict) return PatchTensorFlowEager.__trains_event_writer @staticmethod @@ -812,7 +893,7 @@ class PatchTensorFlowEager(object): @staticmethod def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy()) @@ -822,7 +903,7 @@ class PatchTensorFlowEager(object): @staticmethod def _write_hist_summary(writer, step, tag, values, name, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy()) @@ -832,7 +913,7 @@ class PatchTensorFlowEager(object): @staticmethod def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs): - event_writer = PatchTensorFlowEager._get_event_writer() + event_writer = PatchTensorFlowEager._get_event_writer(writer) if event_writer: try: event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(), @@ -1351,93 +1432,3 @@ class PatchTensorflowModelIO(object): pass return model - -class PatchPyTorchModelIO(object): - 
__main_task = None - __patched = None - - @staticmethod - def update_current_task(task, **kwargs): - PatchPyTorchModelIO.__main_task = task - PatchPyTorchModelIO._patch_model_io() - PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) - - @staticmethod - def _patch_model_io(): - if PatchPyTorchModelIO.__patched: - return - - if 'torch' not in sys.modules: - return - - PatchPyTorchModelIO.__patched = True - # noinspection PyBroadException - try: - # hack: make sure tensorflow.__init__ is called - import torch - torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save) - torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load) - except ImportError: - pass - except Exception: - pass # print('Failed patching pytorch') - - @staticmethod - def _save(original_fn, obj, f, *args, **kwargs): - ret = original_fn(obj, f, *args, **kwargs) - if not PatchPyTorchModelIO.__main_task: - return ret - - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - # noinspection PyBroadException - try: - f.flush() - except Exception: - pass - else: - filename = None - - # give the model a descriptive name based on the file name - # noinspection PyBroadException - try: - model_name = Path(filename).stem - except Exception: - model_name = None - WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, - singlefile=True, model_name=model_name) - return ret - - @staticmethod - def _load(original_fn, f, *args, **kwargs): - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - else: - filename = None - - if not PatchPyTorchModelIO.__main_task: - return original_fn(f, *args, **kwargs) - - # register input model - empty = _Empty() - if running_remotely(): - filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) - model = original_fn(filename or f, *args, **kwargs) - else: - # try to load model before registering, in case we fail - model = original_fn(filename or f, *args, **kwargs) - WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch, - PatchPyTorchModelIO.__main_task) - - if empty.trains_in_model: - # noinspection PyBroadException - try: - model.trains_in_model = empty.trains_in_model - except Exception: - pass - return model From 50ce49a3dd87389d22d43cf24292fff888f6915c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 23:09:44 +0300 Subject: [PATCH 3/7] Add separate api/web/file server configuration (backward support included). 
OS environment override with: TRAINS_API_HOST / TRAINS_WEB_HOST / TRAINS_FILES_HOST --- trains/backend_interface/task/task.py | 38 +++++++++++---------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index f207f1d9..15a45d29 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -77,6 +77,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): task_id = self._resolve_task_id(task_id, log=log) if not force_create else None self._edit_lock = RLock() super(Task, self).__init__(id=task_id, session=session, log=log) + self._project_name = None self._storage_uri = None self._input_model = None self._output_model = None @@ -87,6 +88,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._parameters_allowed_types = ( six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None)) ) + self._app_server = None + self._files_server = None if not task_id: # generate a new task @@ -656,8 +659,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if self.project is None: return None + if self._project_name and self._project_name[0] == self.project: + return self._project_name[1] + res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False) - return res.response.project.name + self._project_name = (self.project, res.response.project.name) + return self._project_name[1] def get_tags(self): return self._get_task_property("tags") @@ -668,33 +675,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._edit(tags=self.data.tags) def _get_default_report_storage_uri(self): - app_host = self._get_app_server() - parsed = urlparse(app_host) - if parsed.port: - parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1)) - elif parsed.netloc.startswith('demoapp.'): - parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1)) - elif parsed.netloc.startswith('app.'): - parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1)) - else: - parsed = parsed._replace(netloc=parsed.netloc+':8081') - return urlunparse(parsed) + if not self._files_server: + self._files_server = Session.get_files_server_host() + return self._files_server @classmethod def _get_api_server(cls): return Session.get_api_server_host() - @classmethod - def _get_app_server(cls): - host = cls._get_api_server() - if '://demoapi.' in host: - return host.replace('://demoapi.', '://demoapp.', 1) - if '://api.' 
in host: - return host.replace('://api.', '://app.', 1) - - parsed = urlparse(host) - if parsed.port == 8008: - return host.replace(':8008', ':8080', 1) + def _get_app_server(self): + if not self._app_server: + self._app_server = Session.get_app_server_host() + return self._app_server def _edit(self, **kwargs): with self._edit_lock: From c80aae0e1ea96e4e83d608a08c2bc0699de9008d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 23:11:54 +0300 Subject: [PATCH 4/7] Fix support for sub-process (process pool) --- trains/backend_interface/base.py | 24 +-- trains/backend_interface/metrics/reporter.py | 30 +++- trains/binding/environ_bind.py | 42 +++++ trains/logger.py | 24 ++- trains/task.py | 154 +++++++++++++------ trains/utilities/async_manager.py | 21 ++- 6 files changed, 220 insertions(+), 75 deletions(-) diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py index b2463c43..50649d4e 100644 --- a/trains/backend_interface/base.py +++ b/trains/backend_interface/base.py @@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface): if log: log.error(error_msg) - if res.meta.result_code <= 500: - # Proper backend error/bad status code - raise or return - if raise_on_errors: - raise SendError(res, error_msg) - return res - except requests.exceptions.BaseHTTPError as e: - log.error('failed sending %s: %s' % (str(req), str(e))) + res = None + log.error('Failed sending %s: %s' % (str(req), str(e))) + except Exception as e: + res = None + log.error('Failed sending %s: %s' % (str(req), str(e))) - # Infrastructure error - if log: - log.info('retrying request %s' % str(req)) + if res and res.meta.result_code <= 500: + # Proper backend error/bad status code - raise or return + if raise_on_errors: + raise SendError(res, error_msg) + return res + + # # Infrastructure error + # if log: + # log.info('retrying request %s' % str(req)) def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False): return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors, diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index e29074a6..c2fd08f3 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -1,8 +1,8 @@ import collections import json -import cv2 import six +from threading import Thread, Event from ..base import InterfaceBase from ..setupuploadmixin import SetupUploadMixin @@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan self._bucket_config = None self._storage_uri = None self._async_enable = async_enable + self._flush_frequency = 30.0 + self._exit_flag = False + self._flush_event = Event() + self._flush_event.clear() + self._thread = Thread(target=self._daemon) + self._thread.daemon = True + self._thread.start() def _set_storage_uri(self, value): value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x) @@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan def async_enable(self, value): self._async_enable = bool(value) + def _daemon(self): + while not self._exit_flag: + self._flush_event.wait(self._flush_frequency) + self._flush_event.clear() + self._write() + # wait for all reports + if self.get_num_results() > 0: + self.wait_for_results() + def _report(self, ev): self._events.append(ev) if len(self._events) >= self._flush_threshold: - self._write() + self.flush() def 
_write(self): if not self._events: @@ -88,10 +104,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan """ Flush cached reports to backend. """ - self._write() - # wait for all reports - if self.get_num_results() > 0: - self.wait_for_results() + self._flush_event.set() + + def stop(self): + self._exit_flag = True + self._flush_event.set() + self._thread.join() def report_scalar(self, title, series, value, iter): """ diff --git a/trains/binding/environ_bind.py b/trains/binding/environ_bind.py index 1238a406..799125ec 100644 --- a/trains/binding/environ_bind.py +++ b/trains/binding/environ_bind.py @@ -1,5 +1,7 @@ import os +import six + from ..config import TASK_LOG_ENVIRONMENT, running_remotely @@ -34,3 +36,43 @@ class EnvironmentBind(object): if running_remotely(): # put back into os: os.environ.update(env_param) + + +class PatchOsFork(object): + _original_fork = None + + @classmethod + def patch_fork(cls): + # only once + if cls._original_fork: + return + if six.PY2: + cls._original_fork = staticmethod(os.fork) + else: + cls._original_fork = os.fork + os.fork = cls._patched_fork + + @staticmethod + def _patched_fork(*args, **kwargs): + ret = PatchOsFork._original_fork(*args, **kwargs) + # Make sure the new process stdout is logged + if not ret: + from ..task import Task + if Task.current_task() is not None: + # bind sub-process logger + task = Task.init() + task.get_logger().flush() + + # if we got here patch the os._exit of our instance to call us + def _at_exit_callback(*args, **kwargs): + # call at exit manually + # noinspection PyProtectedMember + task._at_exit() + # noinspection PyProtectedMember + return os._org_exit(*args, **kwargs) + + if not hasattr(os, '_org_exit'): + os._org_exit = os._exit + os._exit = _at_exit_callback + + return ret diff --git a/trains/logger.py b/trains/logger.py index 8d208512..b364b1ef 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -81,11 +81,14 @@ class Logger(object): self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100) # noinspection PyBroadException try: - Logger._stdout_original_write = sys.stdout.write + if Logger._stdout_original_write is None: + Logger._stdout_original_write = sys.stdout.write # this will only work in python 3, guard it with try/catch - sys.stdout._original_write = sys.stdout.write + if not hasattr(sys.stdout, '_original_write'): + sys.stdout._original_write = sys.stdout.write sys.stdout.write = stdout__patched__write__ - sys.stderr._original_write = sys.stderr.write + if not hasattr(sys.stderr, '_original_write'): + sys.stderr._original_write = sys.stderr.write sys.stderr.write = stderr__patched__write__ except Exception: pass @@ -113,6 +116,7 @@ class Logger(object): msg='Logger failed casting log level "%s" to integer' % str(level)) level = logging.INFO + # noinspection PyBroadException try: record = self._task.log.makeRecord( "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None @@ -128,6 +132,7 @@ class Logger(object): if not omit_console: # if we are here and we grabbed the stdout, we need to print the real thing if DevWorker.report_stdout: + # noinspection PyBroadException try: # make sure we are writing to the original stdout Logger._stdout_original_write(str(msg)+'\n') @@ -637,11 +642,13 @@ class Logger(object): @classmethod def _remove_std_logger(self): if isinstance(sys.stdout, PrintPatchLogger): + # noinspection PyBroadException try: sys.stdout.connect(None) except Exception: pass if isinstance(sys.stderr, 
PrintPatchLogger): + # noinspection PyBroadException try: sys.stderr.connect(None) except Exception: @@ -711,7 +718,13 @@ class PrintPatchLogger(object): if cur_line: with PrintPatchLogger.recursion_protect_lock: - self._log.console(cur_line, level=self._log_level, omit_console=True) + # noinspection PyBroadException + try: + if self._log: + self._log.console(cur_line, level=self._log_level, omit_console=True) + except Exception: + # what can we do, nothing + pass else: if hasattr(self._terminal, '_original_write'): self._terminal._original_write(message) @@ -719,8 +732,7 @@ class PrintPatchLogger(object): self._terminal.write(message) def connect(self, logger): - if self._log: - self._log._flush_stdout_handler() + self._cur_line = '' self._log = logger def __getattr__(self, attr): diff --git a/trains/task.py b/trains/task.py index 1952a9a3..9939deaa 100644 --- a/trains/task.py +++ b/trains/task.py @@ -26,7 +26,7 @@ from .errors import UsageError from .logger import Logger from .model import InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters -from .binding.environ_bind import EnvironmentBind +from .binding.environ_bind import EnvironmentBind, PatchOsFork from .binding.absl_bind import PatchAbsl from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask @@ -66,6 +66,7 @@ class Task(_Task): __create_protection = object() __main_task = None __exit_hook = None + __forked_proc_main_pid = None __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0)) __store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False) __detect_repo_async = config.get('development.vcs_repo_detect_async', False) @@ -104,7 +105,6 @@ class Task(_Task): self._resource_monitor = None # register atexit, so that we mark the task as stopped self._at_exit_called = False - self.__register_at_exit(self._at_exit) @classmethod def current_task(cls): @@ -132,9 +132,10 @@ class Task(_Task): :param project_name: project to create the task in (if project doesn't exist, it will be created) :param task_name: task name to be created (in development mode, not when running remotely) :param task_type: task type to be created (in development mode, not when running remotely) - :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \ - if False every time we call the function we create a new task with the same name \ - Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \ + :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). + if False every time we call the function we create a new task with the same name + Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) + If reuse_last_task_id is of type string, it will assume this is the task_id to reuse! Note: A closed or published task will not be reused, and a new task will be created. :param output_uri: Default location for output models (currently support folder/S3/GS/ ). notice: sub-folders (task_id) is created in the destination folder for all outputs. 
@@ -166,12 +167,31 @@
                     )
         if cls.__main_task is not None:
+            # if this is a subprocess, regardless of what the init was called for,
+            # we have to fix the main task hooks and stdout bindings
+            if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
+                # make sure we only do it once per process
+                cls.__forked_proc_main_pid = os.getpid()
+                # make sure we do not wait for the repo detect thread
+                cls.__main_task._detect_repo_async_thread = None
+                # remove the logger from the previous process
+                logger = cls.__main_task.get_logger()
+                logger.set_flush_period(None)
+                # create a new logger (to catch stdout/err)
+                cls.__main_task._logger = None
+                cls.__main_task._reporter = None
+                cls.__main_task.get_logger()
+                # unregister signal hooks, they cause subprocess to hang
+                cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
+                cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
+
             if not running_remotely():
                 verify_defaults_match()
             return cls.__main_task
-        # check that we are not a child process, in that case do nothing
+        # check that we are not a child process, in that case do nothing.
+        # we should not get here unless this is Windows platform, all others support fork
         if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
             class _TaskStub(object):
                 def __call__(self, *args, **kwargs):
@@ -212,9 +232,10 @@
                 raise
         else:
             Task.__main_task = task
-            # Patch argparse to be aware of the current task
-            argparser_update_currenttask(Task.__main_task)
-            EnvironmentBind.update_current_task(Task.__main_task)
+            # register the main task for at exit hooks (there should only be one)
+            task.__register_at_exit(task._at_exit)
+            # patch OS forking
+            PatchOsFork.patch_fork()
             if auto_connect_frameworks:
                 PatchedMatplotlib.update_current_task(Task.__main_task)
                 PatchAbsl.update_current_task(Task.__main_task)
@@ -227,21 +248,19 @@
             if auto_resource_monitoring:
                 task._resource_monitor = ResourceMonitor(task)
                 task._resource_monitor.start()
-            # Check if parse args already called. If so, sync task parameters with parser
-            if argparser_parseargs_called():
-                parser, parsed_args = get_argparser_last_args()
-                task._connect_argparse(parser=parser, parsed_args=parsed_args)
         # make sure all random generators are initialized with new seed
         make_deterministic(task.get_random_seed())
         if auto_connect_arg_parser:
+            EnvironmentBind.update_current_task(Task.__main_task)
+            # Patch ArgParser to be aware of the current task
             argparser_update_currenttask(Task.__main_task)
             # Check if parse args already called. If so, sync task parameters with parser
             if argparser_parseargs_called():
                 parser, parsed_args = get_argparser_last_args()
-                task._connect_argparse(parser, parsed_args=parsed_args)
+                task._connect_argparse(parser=parser, parsed_args=parsed_args)
         # Make sure we start the logger, it will patch the main logging object and pipe all output
         # if we are running locally and using development mode worker, we will pipe all stdout to logger.
@@ -339,7 +358,9 @@ class Task(_Task): in_dev_mode = not running_remotely() if in_dev_mode: - if not reuse_last_task_id or not cls.__task_is_relevant(default_task): + if isinstance(reuse_last_task_id, str) and reuse_last_task_id: + default_task_id = reuse_last_task_id + elif not reuse_last_task_id or not cls.__task_is_relevant(default_task): default_task_id = None closed_old_task = cls.__close_timed_out_task(default_task) else: @@ -600,6 +621,9 @@ class Task(_Task): """ self._at_exit() self._at_exit_called = False + # unregister atexit callbacks and signal hooks, if we are the main task + if self.is_main_task(): + self.__register_at_exit(None) def is_current_task(self): """ @@ -914,9 +938,12 @@ class Task(_Task): Will happen automatically once we exit code, i.e. atexit :return: """ + # protect sub-process at_exit if self._at_exit_called: return + is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid() + # noinspection PyBroadException try: # from here do not get into watch dog @@ -948,28 +975,32 @@ class Task(_Task): # from here, do not send log in background thread if wait_for_uploads: self.flush(wait_for_uploads=True) + # wait until the reporter flush everything + self.reporter.stop() if print_done_waiting: self.log.info('Finished uploading') else: self._logger._flush_stdout_handler() - # from here, do not check worker status - if self._dev_worker: - self._dev_worker.unregister() + if not is_sub_process: + # from here, do not check worker status + if self._dev_worker: + self._dev_worker.unregister() - # change task status - if not task_status: - pass - elif task_status[0] == 'failed': - self.mark_failed(status_reason=task_status[1]) - elif task_status[0] == 'completed': - self.completed() - elif task_status[0] == 'stopped': - self.stopped() + # change task status + if not task_status: + pass + elif task_status[0] == 'failed': + self.mark_failed(status_reason=task_status[1]) + elif task_status[0] == 'completed': + self.completed() + elif task_status[0] == 'stopped': + self.stopped() # stop resource monitoring if self._resource_monitor: self._resource_monitor.stop() + self._logger.set_flush_period(None) # this is so in theory we can close a main task and start a new one Task.__main_task = None @@ -978,7 +1009,7 @@ class Task(_Task): pass @classmethod - def __register_at_exit(cls, exit_callback): + def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False): class ExitHooks(object): _orig_exit = None _orig_exc_handler = None @@ -1000,7 +1031,21 @@ class Task(_Task): except Exception: pass self._exit_callback = callback - atexit.register(self._exit_callback) + if callback: + self.hook() + else: + # un register int hook + print('removing int hook', self._orig_exc_handler) + if self._orig_exc_handler: + sys.excepthook = self._orig_exc_handler + self._orig_exc_handler = None + for s in self._org_handlers: + # noinspection PyBroadException + try: + signal.signal(s, self._org_handlers[s]) + except Exception: + pass + self._org_handlers = {} def hook(self): if self._orig_exit is None: @@ -1009,20 +1054,23 @@ class Task(_Task): if self._orig_exc_handler is None: self._orig_exc_handler = sys.excepthook sys.excepthook = self.exc_handler - atexit.register(self._exit_callback) - if sys.platform == 'win32': - catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, - signal.SIGILL, signal.SIGFPE] - else: - catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, - signal.SIGILL, 
signal.SIGFPE, signal.SIGQUIT] - for s in catch_signals: - # noinspection PyBroadException - try: - self._org_handlers[s] = signal.getsignal(s) - signal.signal(s, self.signal_handler) - except Exception: - pass + if self._exit_callback: + atexit.register(self._exit_callback) + + if self._org_handlers: + if sys.platform == 'win32': + catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, + signal.SIGILL, signal.SIGFPE] + else: + catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, + signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] + for s in catch_signals: + # noinspection PyBroadException + try: + self._org_handlers[s] = signal.getsignal(s) + signal.signal(s, self.signal_handler) + except Exception: + pass def exit(self, code=0): self.exit_code = code @@ -1077,6 +1125,22 @@ class Task(_Task): # return handler result return org_handler + # we only remove the signals since this will hang subprocesses + if only_remove_signal_and_exception_hooks: + if not cls.__exit_hook: + return + if cls.__exit_hook._orig_exc_handler: + sys.excepthook = cls.__exit_hook._orig_exc_handler + cls.__exit_hook._orig_exc_handler = None + for s in cls.__exit_hook._org_handlers: + # noinspection PyBroadException + try: + signal.signal(s, cls.__exit_hook._org_handlers[s]) + except Exception: + pass + cls.__exit_hook._org_handlers = {} + return + if cls.__exit_hook is None: # noinspection PyBroadException try: @@ -1084,13 +1148,13 @@ class Task(_Task): cls.__exit_hook.hook() except Exception: cls.__exit_hook = None - elif cls.__main_task is None: + else: cls.__exit_hook.update_callback(exit_callback) @classmethod def __get_task(cls, task_id=None, project_name=None, task_name=None): if task_id: - return cls(private=cls.__create_protection, task_id=task_id) + return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) res = cls._send( cls._get_default_session(), diff --git a/trains/utilities/async_manager.py b/trains/utilities/async_manager.py index 45732d67..3c19441a 100644 --- a/trains/utilities/async_manager.py +++ b/trains/utilities/async_manager.py @@ -1,3 +1,4 @@ +import os import time from threading import Lock @@ -6,7 +7,8 @@ import six class AsyncManagerMixin(object): _async_results_lock = Lock() - _async_results = [] + # per pid (process) list of async jobs (support for sub-processes forking) + _async_results = {} @classmethod def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None): @@ -14,8 +16,9 @@ class AsyncManagerMixin(object): try: cls._async_results_lock.acquire() # discard completed results - cls._async_results = [r for r in cls._async_results if not r.ready()] - num_results = len(cls._async_results) + pid = os.getpid() + cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()] + num_results = len(cls._async_results[pid]) if wait_on_max_results is not None and num_results >= wait_on_max_results: # At least max_results results are still pending, wait if wait_cb: @@ -25,7 +28,7 @@ class AsyncManagerMixin(object): continue # add result if result and not result.ready(): - cls._async_results.append(result) + cls._async_results[pid] = cls._async_results.get(pid, []).append(result) break finally: cls._async_results_lock.release() @@ -34,7 +37,8 @@ class AsyncManagerMixin(object): def wait_for_results(cls, timeout=None, max_num_uploads=None): remaining = timeout count = 0 - for r in cls._async_results: + pid = os.getpid() + for r in cls._async_results.get(pid, []): if 
r.ready(): continue t = time.time() @@ -48,13 +52,14 @@ class AsyncManagerMixin(object): if max_num_uploads is not None and max_num_uploads - count <= 0: break if timeout is not None: - remaining = max(0, remaining - max(0, time.time() - t)) + remaining = max(0., remaining - max(0., time.time() - t)) if not remaining: break @classmethod def get_num_results(cls): - if cls._async_results is not None: - return len([r for r in cls._async_results if not r.ready()]) + pid = os.getpid() + if cls._async_results.get(pid, []): + return len([r for r in cls._async_results.get(pid, []) if not r.ready()]) else: return 0 From c93d0301099792a4bf7b6ab84a3693adff13a160 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 23:55:14 +0300 Subject: [PATCH 5/7] Fix support for sub-process (process pool) --- trains/utilities/async_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trains/utilities/async_manager.py b/trains/utilities/async_manager.py index 3c19441a..fd57e11d 100644 --- a/trains/utilities/async_manager.py +++ b/trains/utilities/async_manager.py @@ -28,7 +28,9 @@ class AsyncManagerMixin(object): continue # add result if result and not result.ready(): - cls._async_results[pid] = cls._async_results.get(pid, []).append(result) + if not cls._async_results.get(pid): + cls._async_results[pid] = [] + cls._async_results[pid].append(result) break finally: cls._async_results_lock.release() From b0a759c4548b136678fb46f7f9a1ca32175a525e Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 23:55:34 +0300 Subject: [PATCH 6/7] Add Seaborn to Matplotlib example --- examples/matplotlib_example.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/matplotlib_example.py b/examples/matplotlib_example.py index f918bed7..2398b250 100644 --- a/examples/matplotlib_example.py +++ b/examples/matplotlib_example.py @@ -1,7 +1,8 @@ -# TRAINS - Example of Matplotlib integration and reporting +# TRAINS - Example of Matplotlib and Seaborn integration and reporting # import numpy as np import matplotlib.pyplot as plt +import seaborn as sns from trains import Task @@ -33,4 +34,13 @@ plt.imshow(m) plt.title('Image Title') plt.show() -print('This is a Matplotlib example') +sns.set(style="darkgrid") +# Load an example dataset with long-form data +fmri = sns.load_dataset("fmri") +# Plot the responses for different events and regions +sns.lineplot(x="timepoint", y="signal", + hue="region", style="event", + data=fmri) +plt.show() + +print('This is a Matplotlib & Seaborn example') From f19401da39711d9fa619c66ce7ed49ad0ffe03ca Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 20 Jul 2019 23:55:46 +0300 Subject: [PATCH 7/7] version bump --- trains/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trains/version.py b/trains/version.py index 85b551d3..07764289 100644 --- a/trains/version.py +++ b/trains/version.py @@ -1 +1 @@ -__version__ = '0.10.2' +__version__ = '0.10.3rc1'
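
Taken together, PATCH 1/7 and PATCH 3/7 route every endpoint lookup through the new Session classmethods. The following is a minimal sketch of how the resolution can be exercised once the series is applied; it assumes the trains package is installed and that Session is re-exported from trains.backend_api (the same import the backend_interface code uses). The environment variable names and default ports are the ones introduced in the patches above.

    import os

    # Environment overrides introduced in PATCH 1/7; they take precedence over
    # api.api_server / api.web_server / api.files_server in trains.conf
    # (the old api.host key is still honored for backward compatibility).
    os.environ["TRAINS_API_HOST"] = "http://localhost:8008"
    os.environ["TRAINS_WEB_HOST"] = "http://localhost:8080"
    os.environ["TRAINS_FILES_HOST"] = "http://localhost:8081"

    from trains.backend_api import Session  # assumed re-export, as used by backend_interface

    # Each classmethod returns the configured endpoint, or derives the web/files
    # endpoints from the api server host by rewriting the port or subdomain.
    print(Session.get_api_server_host())    # http://localhost:8008
    print(Session.get_app_server_host())    # http://localhost:8080
    print(Session.get_files_server_host())  # http://localhost:8081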