mirror of https://github.com/clearml/clearml synced 2025-05-08 06:44:26 +00:00

Merge remote-tracking branch 'upstream/master'

commit 94642af4a1
Author: Erez Schanider, 2019-07-21 08:48:35 +03:00
16 changed files with 522 additions and 256 deletions
docs
examples
trains
    backend_api
    backend_interface
    binding
    config/default
    logger.py
    task.py
    utilities
    version.py

View File

@@ -1,7 +1,13 @@
 # TRAINS SDK configuration file
 api {
-    # Notice: 'host' is the api server (default port 8008), not the web server.
-    host: http://localhost:8008
+    # web_server on port 8080
+    web_server: "http://localhost:8080"
+    # Notice: 'api_server' is the api server (default port 8008), not the web server.
+    api_server: "http://localhost:8008"
+    # file server on port 8081
+    files_server: "http://localhost:8081"
     # Credentials are generated in the webapp, http://localhost:8080/admin
     credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}

View File

@@ -1,7 +1,8 @@
-# TRAINS - Example of Matplotlib integration and reporting
+# TRAINS - Example of Matplotlib and Seaborn integration and reporting
 #
 import numpy as np
 import matplotlib.pyplot as plt
+import seaborn as sns
 from trains import Task
@@ -33,4 +34,13 @@ plt.imshow(m)
 plt.title('Image Title')
 plt.show()

-print('This is a Matplotlib example')
+sns.set(style="darkgrid")
+# Load an example dataset with long-form data
+fmri = sns.load_dataset("fmri")
+# Plot the responses for different events and regions
+sns.lineplot(x="timepoint", y="signal",
+             hue="region", style="event",
+             data=fmri)
+plt.show()
+
+print('This is a Matplotlib & Seaborn example')

View File

@@ -1,7 +1,11 @@
 {
     version: 1.5

-    # default https://demoapi.trainsai.io host
-    host: ""
+    # default api_server: https://demoapi.trainsai.io
+    api_server: ""
+    # default web_server: https://demoapp.trainsai.io
+    web_server: ""
+    # default files_server: https://demofiles.trainsai.io
+    files_server: ""

     # verify host ssl certificate, set to False only if you have a very good reason
     verify_certificate: True

View File

@@ -2,6 +2,8 @@ from ...backend_config import EnvEntry

 ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
+ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "ALG_WEB_HOST")
+ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "ALG_FILES_HOST")
 ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
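
The two new entries make each endpoint overridable per process. A minimal sketch (values illustrative), set before the SDK reads its configuration:

    import os

    # endpoint overrides read by the EnvEntry definitions above (values illustrative)
    os.environ['TRAINS_API_HOST'] = 'http://localhost:8008'
    os.environ['TRAINS_WEB_HOST'] = 'http://localhost:8080'
    os.environ['TRAINS_FILES_HOST'] = 'http://localhost:8081'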

View File

@@ -2,6 +2,7 @@ import json as json_lib
 import sys
 import types
 from socket import gethostname
+from six.moves.urllib.parse import urlparse, urlunparse

 import jwt
 import requests
@@ -10,11 +11,11 @@ from pyhocon import ConfigTree
 from requests.auth import HTTPBasicAuth

 from .callresult import CallResult
-from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
+from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
 from .request import Request, BatchRequest
 from .token_manager import TokenManager
 from ..config import load
-from ..utils import get_http_session_with_retry
+from ..utils import get_http_session_with_retry, urllib_log_warning_setup
 from ..version import __version__
@@ -32,11 +33,13 @@ class Session(TokenManager):
     _async_status_code = 202
     _session_requests = 0
-    _session_initial_timeout = (1.0, 10)
-    _session_timeout = (5.0, None)
+    _session_initial_timeout = (3.0, 10.)
+    _session_timeout = (5.0, 300.)

     api_version = '2.1'
     default_host = "https://demoapi.trainsai.io"
+    default_web = "https://demoapp.trainsai.io"
+    default_files = "https://demofiles.trainsai.io"
     default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
     default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
@@ -97,7 +100,7 @@ class Session(TokenManager):
         self._logger = logger

         self.__access_key = api_key or ENV_ACCESS_KEY.get(
-            default=(self.config.get("api.credentials.access_key") or self.default_key)
+            default=(self.config.get("api.credentials.access_key", None) or self.default_key)
         )
         if not self.access_key:
             raise ValueError(
@@ -105,7 +108,7 @@ class Session(TokenManager):
         )

         self.__secret_key = secret_key or ENV_SECRET_KEY.get(
-            default=(self.config.get("api.credentials.secret_key") or self.default_secret)
+            default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
         )
         if not self.secret_key:
             raise ValueError(
@@ -125,7 +128,7 @@ class Session(TokenManager):
         self.__worker = worker or gethostname()

-        self.__max_req_size = self.config.get("api.http.max_req_size")
+        self.__max_req_size = self.config.get("api.http.max_req_size", None)
         if not self.__max_req_size:
             raise ValueError("missing max request size")
@@ -140,6 +143,11 @@ class Session(TokenManager):
         except (jwt.DecodeError, ValueError):
             pass

+        # now set up the session retry reporting, so consecutive retries will show a warning
+        # we do that here so that if we have problems authenticating, we see them immediately
+        # notice: this warning suppression applies across the board
+        urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)
+
     def _send_request(
         self,
         service,
@@ -394,7 +402,65 @@ class Session(TokenManager):
         if not config:
             from ...config import config_obj
             config = config_obj
-        return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
+        return ENV_HOST.get(default=(config.get("api.api_server", None) or
+                                     config.get("api.host", None) or cls.default_host))
+
+    @classmethod
+    def get_app_server_host(cls, config=None):
+        if not config:
+            from ...config import config_obj
+            config = config_obj
+
+        # get from config/environment
+        web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
+        if web_host:
+            return web_host
+
+        # return default
+        host = cls.get_api_server_host(config)
+        if host == cls.default_host:
+            return cls.default_web
+
+        # compose ourselves
+        if '://demoapi.' in host:
+            return host.replace('://demoapi.', '://demoapp.', 1)
+        if '://api.' in host:
+            return host.replace('://api.', '://app.', 1)
+        parsed = urlparse(host)
+        if parsed.port == 8008:
+            return host.replace(':8008', ':8080', 1)
+
+        raise ValueError('Could not detect TRAINS web application server')
+
+    @classmethod
+    def get_files_server_host(cls, config=None):
+        if not config:
+            from ...config import config_obj
+            config = config_obj
+
+        # get from config/environment
+        files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
+        if files_host:
+            return files_host
+
+        # return default
+        host = cls.get_api_server_host(config)
+        if host == cls.default_host:
+            return cls.default_files
+
+        # compose ourselves
+        app_host = cls.get_app_server_host(config)
+        parsed = urlparse(app_host)
+        if parsed.port:
+            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
+        elif parsed.netloc.startswith('demoapp.'):
+            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
+        elif parsed.netloc.startswith('app.'):
+            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
+        else:
+            parsed = parsed._replace(netloc=parsed.netloc + ':8081')
+        return urlunparse(parsed)

     def _do_refresh_token(self, old_token, exp=None):
         """ TokenManager abstract method implementation.

View File

@@ -27,6 +27,29 @@ def get_config():
     return config_obj


+def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
+    class RetryFilter(logging.Filter):
+        last_instance = None
+
+        def __init__(self, total, warning_after=5):
+            super(RetryFilter, self).__init__()
+            self.total = total
+            self.display_warning_after = warning_after
+            self.last_instance = self
+
+        def filter(self, record):
+            if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
+                retry_left = self.total - record.args[0].total
+                return retry_left >= self.display_warning_after
+            return True
+
+    urllib3_log = logging.getLogger('urllib3.connectionpool')
+    if urllib3_log:
+        urllib3_log.removeFilter(RetryFilter.last_instance)
+        urllib3_log.addFilter(RetryFilter(total_retries, display_warning_after))
+
+
 class TLSv1HTTPAdapter(HTTPAdapter):
     def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
         self.poolmanager = PoolManager(num_pools=connections,
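
The idea generalizes: a logging.Filter attached to the urllib3.connectionpool logger can hide the first few retry warnings while letting persistent ones through. A minimal standalone demonstration (names illustrative, not the SDK code):

    import logging

    class QuietEarlyRetries(logging.Filter):
        """Drop the first `threshold` retry warnings, pass the rest (illustrative)."""
        def __init__(self, total, threshold):
            super(QuietEarlyRetries, self).__init__()
            self.total = total
            self.threshold = threshold

        def filter(self, record):
            # urllib3 passes the Retry object as the first log argument;
            # Retry.total counts down, so (total - retry.total) is how many retries happened
            retry = record.args[0] if record.args else None
            if retry is not None and hasattr(retry, 'total'):
                return (self.total - retry.total) >= self.threshold
            return True

    logging.getLogger('urllib3.connectionpool').addFilter(QuietEarlyRetries(total=10, threshold=5))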

View File

@@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
                     if log:
                         log.error(error_msg)

-            if res.meta.result_code <= 500:
-                # Proper backend error/bad status code - raise or return
-                if raise_on_errors:
-                    raise SendError(res, error_msg)
-                return res
         except requests.exceptions.BaseHTTPError as e:
-            log.error('failed sending %s: %s' % (str(req), str(e)))
+            res = None
+            log.error('Failed sending %s: %s' % (str(req), str(e)))
+        except Exception as e:
+            res = None
+            log.error('Failed sending %s: %s' % (str(req), str(e)))

-        # Infrastructure error
-        if log:
-            log.info('retrying request %s' % str(req))
+        if res and res.meta.result_code <= 500:
+            # Proper backend error/bad status code - raise or return
+            if raise_on_errors:
+                raise SendError(res, error_msg)
+            return res
+
+        # # Infrastructure error
+        # if log:
+        #     log.info('retrying request %s' % str(req))

     def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
         return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,

View File

@@ -1,8 +1,8 @@
 import collections
 import json
-import cv2

 import six
+from threading import Thread, Event

 from ..base import InterfaceBase
 from ..setupuploadmixin import SetupUploadMixin
@@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         self._bucket_config = None
         self._storage_uri = None
         self._async_enable = async_enable
+        self._flush_frequency = 30.0
+        self._exit_flag = False
+        self._flush_event = Event()
+        self._flush_event.clear()
+        self._thread = Thread(target=self._daemon)
+        self._thread.daemon = True
+        self._thread.start()

     def _set_storage_uri(self, value):
         value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
@@ -70,10 +77,19 @@
     def async_enable(self, value):
         self._async_enable = bool(value)

+    def _daemon(self):
+        while not self._exit_flag:
+            self._flush_event.wait(self._flush_frequency)
+            self._flush_event.clear()
+            self._write()
+            # wait for all reports
+            if self.get_num_results() > 0:
+                self.wait_for_results()
+
     def _report(self, ev):
         self._events.append(ev)
         if len(self._events) >= self._flush_threshold:
-            self._write()
+            self.flush()

     def _write(self):
         if not self._events:
@@ -88,10 +104,12 @@
         """
         Flush cached reports to backend.
         """
-        self._write()
-        # wait for all reports
-        if self.get_num_results() > 0:
-            self.wait_for_results()
+        self._flush_event.set()
+
+    def stop(self):
+        self._exit_flag = True
+        self._flush_event.set()
+        self._thread.join()

     def report_scalar(self, title, series, value, iter):
         """

View File

@@ -77,6 +77,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
         self._edit_lock = RLock()
         super(Task, self).__init__(id=task_id, session=session, log=log)
+        self._project_name = None
         self._storage_uri = None
         self._input_model = None
         self._output_model = None
@@ -87,6 +88,8 @@
         self._parameters_allowed_types = (
             six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
         )
+        self._app_server = None
+        self._files_server = None

         if not task_id:
             # generate a new task
@@ -656,8 +659,12 @@
         if self.project is None:
             return None

+        if self._project_name and self._project_name[0] == self.project:
+            return self._project_name[1]
+
         res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
-        return res.response.project.name
+        self._project_name = (self.project, res.response.project.name)
+        return self._project_name[1]

     def get_tags(self):
         return self._get_task_property("tags")
@@ -668,33 +675,18 @@
         self._edit(tags=self.data.tags)

     def _get_default_report_storage_uri(self):
-        app_host = self._get_app_server()
-        parsed = urlparse(app_host)
-        if parsed.port:
-            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
-        elif parsed.netloc.startswith('demoapp.'):
-            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
-        elif parsed.netloc.startswith('app.'):
-            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
-        else:
-            parsed = parsed._replace(netloc=parsed.netloc + ':8081')
-        return urlunparse(parsed)
+        if not self._files_server:
+            self._files_server = Session.get_files_server_host()
+        return self._files_server

     @classmethod
     def _get_api_server(cls):
         return Session.get_api_server_host()

-    @classmethod
-    def _get_app_server(cls):
-        host = cls._get_api_server()
-        if '://demoapi.' in host:
-            return host.replace('://demoapi.', '://demoapp.', 1)
-        if '://api.' in host:
-            return host.replace('://api.', '://app.', 1)
-        parsed = urlparse(host)
-        if parsed.port == 8008:
-            return host.replace(':8008', ':8080', 1)
+    def _get_app_server(self):
+        if not self._app_server:
+            self._app_server = Session.get_app_server_host()
+        return self._app_server

     def _edit(self, **kwargs):
         with self._edit_lock:

View File

@@ -1,5 +1,7 @@
 import os

+import six
+
 from ..config import TASK_LOG_ENVIRONMENT, running_remotely
@@ -34,3 +36,43 @@ class EnvironmentBind(object):
         if running_remotely():
             # put back into os:
             os.environ.update(env_param)
+
+
+class PatchOsFork(object):
+    _original_fork = None
+
+    @classmethod
+    def patch_fork(cls):
+        # only once
+        if cls._original_fork:
+            return
+        if six.PY2:
+            cls._original_fork = staticmethod(os.fork)
+        else:
+            cls._original_fork = os.fork
+        os.fork = cls._patched_fork
+
+    @staticmethod
+    def _patched_fork(*args, **kwargs):
+        ret = PatchOsFork._original_fork(*args, **kwargs)
+        # Make sure the new process stdout is logged
+        if not ret:
+            from ..task import Task
+            if Task.current_task() is not None:
+                # bind the sub-process logger
+                task = Task.init()
+                task.get_logger().flush()
+
+                # if we got here, patch this instance's os._exit to call us first
+                def _at_exit_callback(*args, **kwargs):
+                    # call at exit manually
+                    # noinspection PyProtectedMember
+                    task._at_exit()
+                    # noinspection PyProtectedMember
+                    return os._org_exit(*args, **kwargs)
+
+                if not hasattr(os, '_org_exit'):
+                    os._org_exit = os._exit
+                os._exit = _at_exit_callback
+        return ret
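
The same two-step trick can be shown in isolation: wrap os.fork, detect the child via the zero return value, and shadow os._exit so cleanup runs even when the child bypasses normal interpreter shutdown. A minimal POSIX-only sketch (names illustrative):

    import os

    _orig_fork = os.fork    # POSIX only; os.fork does not exist on Windows
    _orig_exit = os._exit

    def _patched_exit(code):
        # stand-in for task._at_exit(): flush logs before the child dies
        print('child %d exiting, flushing' % os.getpid())
        _orig_exit(code)

    def _patched_fork(*args, **kwargs):
        pid = _orig_fork(*args, **kwargs)
        if pid == 0:
            # we are the child: os._exit skips atexit handlers, so shadow it too
            os._exit = _patched_exit
        return pid

    os.fork = _patched_fork

    if __name__ == '__main__':
        child = os.fork()
        if child == 0:
            os._exit(0)          # runs _patched_exit first
        os.waitpid(child, 0)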

View File

@@ -1,4 +1,5 @@
 import base64
+import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     Trains events and reports the events (metrics) for a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}

+    # store all the created tensorboard writers in the system
+    # this allows us to ask whether a certain title/series already exists on some EventWriter,
+    # and if it does, we add the last token from the logdir to the series name
+    # (so we can differentiate between the two)
+    # key, value: key=hash(title, graph), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
+
     @property
     def variants(self):
         return self._variants
@@ -54,8 +63,8 @@
     def prepare_report(self):
         return self.variants.copy()

-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series=header: series, if 'series' then series=series :header,
+            if 'title_last' then title=header title, if 'title' then title=title header
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # if there is someone else, change our series name and store us
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logdir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logdir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logdir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last'
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # well, that's about it, nothing else we can do
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last'
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
         return title, series

-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
@@ -87,6 +145,9 @@
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
@@ -99,6 +160,7 @@
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}

     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@
         if img_data_np is None:
             return

-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -168,7 +230,7 @@
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)

     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')

         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
@@ -216,7 +278,8 @@
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')

         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
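
As a worked example of the splitting itself, before any logdir disambiguation kicks in: the last num_split_parts path tokens become the series and the rest becomes the title (values illustrative):

    tag = 'validation/accuracy/top1'
    split_char, join_char = '/', '_'

    parts = tag.split(split_char)
    series = join_char.join(parts[-1:])              # num_split_parts=1 -> 'top1'
    title = join_char.join(parts[:-1]) or 'Scalars'  # -> 'validation_accuracy'
    print(title, series)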
@@ -570,8 +633,12 @@
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -584,8 +651,12 @@
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)

         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
@@ -798,12 +874,17 @@
             getLogger(TrainsFrameworkAdapter).warning(str(ex))

     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer

     @staticmethod
@@ -812,7 +893,7 @@
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@@ -1350,93 +1431,3 @@
                 pass
         return model
-
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-        if 'torch' not in sys.modules:
-            return
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure tensorflow.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load model before registering, in case we fail
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model

View File

@@ -11,19 +11,20 @@ from trains.config import config_obj

 description = """
-Please create new credentials using the web app: {}/admin
+Please create new credentials using the web app: {}/profile
 In the Admin page, press "Create new credentials", then press "Copy to clipboard"

 Paste credentials here: """

 try:
-    def_host = ENV_HOST.get(default=config_obj.get("api.host"))
+    def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080'
 except Exception:
     def_host = 'http://localhost:8080'

 host_description = """
 Editing configuration file: {CONFIG_FILE}
-Enter the url of the trains-server's api service, for example: http://localhost:8008 : """.format(
+Enter the url of the trains-server's Web service, for example: {HOST}
+""".format(
     CONFIG_FILE=LOCAL_CONFIG_FILES[0],
     HOST=def_host,
 )
@@ -37,64 +38,60 @@ def main():
         print('Leaving setup, feel free to edit the configuration file.')
         return

-    print(host_description, end='')
-    parsed_host = None
-    while not parsed_host:
-        parse_input = input()
-        if not parse_input:
-            parse_input = def_host
-        # noinspection PyBroadException
-        try:
-            if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
-                parse_input = 'http://' + parse_input
-            parsed_host = urlparse(parse_input)
-            if parsed_host.scheme not in ('http', 'https'):
-                parsed_host = None
-        except Exception:
-            parsed_host = None
-            print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
+    print(host_description)
+    web_host = input_url('Web Application Host', '')
+    parsed_host = verify_url(web_host)

-    if parsed_host.port == 8080:
-        # this is a docker; 8080 is the web address, we need the api address, it is 8008
-        print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-    elif parsed_host.netloc.startswith('demoapp.'):
-        print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
-            parsed_host.netloc))
-        # this is our demo server
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-    elif parsed_host.netloc.startswith('app.'):
-        print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
-            parsed_host.netloc))
-        # this is our application server
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-    elif parsed_host.port == 8008:
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
-    elif parsed_host.netloc.startswith('demoapi.'):
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
-    elif parsed_host.netloc.startswith('api.'):
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
-    else:
-        api_host = None
-        web_host = None
+    if parsed_host.port == 8008:
+        print('Port 8008 is the api port. Replacing 8008 with 8080 for the Web application')
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
+    elif parsed_host.port == 8080:
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
+    elif parsed_host.netloc.startswith('demoapp.'):
+        # this is our demo server
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
+    elif parsed_host.netloc.startswith('app.'):
+        # this is our application server
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
+    elif parsed_host.netloc.startswith('demoapi.'):
+        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
+            parsed_host.netloc))
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
+    elif parsed_host.netloc.startswith('api.'):
+        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
+            parsed_host.netloc))
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
+    else:
+        api_host = ''
+        web_host = ''
+        files_host = ''
         if not parsed_host.port:
             print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
             replace_port = input().lower()
             if not replace_port or replace_port == 'y' or replace_port == 'yes':
                 api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                 web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
+                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
         if not api_host:
             api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        if not web_host:
-            web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path

-    print('Host configured to: {}'.format(api_host))
+    api_host = input_url('API Host', api_host)
+    files_host = input_url('File Store Host', files_host)
+    print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
+        api_host, web_host, files_host))

     print(description.format(web_host), end='')
     parse_input = input()
@@ -133,11 +130,14 @@ def main():
         header = '# TRAINS SDK configuration file\n' \
                  'api {\n' \
-                 '    # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
-                 '    host: %s\n' \
-                 '    # Credentials are generated in the webapp, %s/admin\n' \
+                 '    api_server: %s\n' \
+                 '    web_server: %s\n' \
+                 '    files_server: %s\n' \
+                 '    # Credentials are generated in the webapp, %s/profile\n' \
                  '    credentials {"access_key": "%s", "secret_key": "%s"}\n' \
                  '}\n' \
-                 'sdk ' % (api_host, web_host, credentials['access_key'], credentials['secret_key'])
+                 'sdk ' % (api_host, web_host, files_host,
+                           web_host, credentials['access_key'], credentials['secret_key'])
         f.write(header)
         f.write(default_sdk)
     except Exception:
@@ -148,5 +148,30 @@ def main():
     print('TRAINS setup completed successfully.')


+def input_url(host_type, host=None):
+    while True:
+        print('{} configured to: [{}] '.format(host_type, host), end='')
+        parse_input = input()
+        if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
+            break
+        if parse_input and verify_url(parse_input):
+            host = parse_input
+            break
+    return host
+
+
+def verify_url(parse_input):
+    try:
+        if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
+            parse_input = 'http://' + parse_input
+        parsed_host = urlparse(parse_input)
+        if parsed_host.scheme not in ('http', 'https'):
+            parsed_host = None
+    except Exception:
+        parsed_host = None
+        print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
+    return parsed_host
+
+
 if __name__ == '__main__':
     main()
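
A quick illustration of the URL fallback the new helpers rely on (standalone sketch, hostnames illustrative):

    from six.moves.urllib.parse import urlparse

    def normalize(url):
        # same fallback verify_url applies: assume http:// when no scheme is given
        if not url.startswith(('http://', 'https://')):
            url = 'http://' + url
        return urlparse(url)

    print(normalize('localhost:8080').netloc)  # -> localhost:8080
    print(normalize('app.example.com').port)   # -> None, so the wizard offers the default ports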

View File

@@ -81,11 +81,14 @@ class Logger(object):
             self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
             # noinspection PyBroadException
             try:
-                Logger._stdout_original_write = sys.stdout.write
+                if Logger._stdout_original_write is None:
+                    Logger._stdout_original_write = sys.stdout.write
                 # this will only work in python 3, guard it with try/catch
-                sys.stdout._original_write = sys.stdout.write
+                if not hasattr(sys.stdout, '_original_write'):
+                    sys.stdout._original_write = sys.stdout.write
                 sys.stdout.write = stdout__patched__write__
-                sys.stderr._original_write = sys.stderr.write
+                if not hasattr(sys.stderr, '_original_write'):
+                    sys.stderr._original_write = sys.stderr.write
                 sys.stderr.write = stderr__patched__write__
             except Exception:
                 pass
@@ -113,6 +116,7 @@
                           msg='Logger failed casting log level "%s" to integer' % str(level))
             level = logging.INFO

+        # noinspection PyBroadException
         try:
             record = self._task.log.makeRecord(
                 "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
@@ -128,6 +132,7 @@
         if not omit_console:
             # if we are here and we grabbed the stdout, we need to print the real thing
             if DevWorker.report_stdout:
+                # noinspection PyBroadException
                 try:
                     # make sure we are writing to the original stdout
                     Logger._stdout_original_write(str(msg)+'\n')
@@ -637,11 +642,13 @@
     @classmethod
     def _remove_std_logger(self):
         if isinstance(sys.stdout, PrintPatchLogger):
+            # noinspection PyBroadException
             try:
                 sys.stdout.connect(None)
             except Exception:
                 pass
         if isinstance(sys.stderr, PrintPatchLogger):
+            # noinspection PyBroadException
             try:
                 sys.stderr.connect(None)
             except Exception:
                 pass
@@ -711,7 +718,13 @@
                 if cur_line:
                     with PrintPatchLogger.recursion_protect_lock:
-                        self._log.console(cur_line, level=self._log_level, omit_console=True)
+                        # noinspection PyBroadException
+                        try:
+                            if self._log:
+                                self._log.console(cur_line, level=self._log_level, omit_console=True)
+                        except Exception:
+                            # what can we do, nothing
+                            pass
             else:
                 if hasattr(self._terminal, '_original_write'):
                     self._terminal._original_write(message)
@@ -719,8 +732,7 @@
                 self._terminal.write(message)

     def connect(self, logger):
-        if self._log:
-            self._log._flush_stdout_handler()
+        self._cur_line = ''
         self._log = logger

     def __getattr__(self, attr):
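
The guarded re-patching above protects a proxy idiom worth seeing in isolation: forward writes to the real stream, buffer until a full line is available, then mirror it elsewhere. A minimal sketch (illustrative, not PrintPatchLogger itself):

    import sys

    class StdoutCapture(object):
        def __init__(self, terminal):
            self._terminal = terminal
            self._cur_line = ''

        def write(self, message):
            ret = self._terminal.write(message)   # keep normal console output
            self._cur_line += message
            while '\n' in self._cur_line:
                line, self._cur_line = self._cur_line.split('\n', 1)
                if line:
                    pass  # a real logger would queue `line` for background reporting
            return ret

        def __getattr__(self, attr):
            # delegate flush(), isatty(), etc. to the wrapped stream
            return getattr(self._terminal, attr)

    sys.stdout = StdoutCapture(sys.stdout)
    print('captured and still printed')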

View File

@@ -26,7 +26,7 @@ from .errors import UsageError
 from .logger import Logger
 from .model import InputModel, OutputModel, ARCHIVED_TAG
 from .task_parameters import TaskParameters
-from .binding.environ_bind import EnvironmentBind
+from .binding.environ_bind import EnvironmentBind, PatchOsFork
 from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
@@ -66,6 +66,7 @@ class Task(_Task):
     __create_protection = object()
     __main_task = None
     __exit_hook = None
+    __forked_proc_main_pid = None
     __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
     __store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
     __detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@@ -104,7 +105,6 @@
         self._resource_monitor = None
         # register atexit, so that we mark the task as stopped
         self._at_exit_called = False
-        self.__register_at_exit(self._at_exit)

     @classmethod
     def current_task(cls):
@@ -132,9 +132,10 @@
         :param project_name: project to create the task in (if project doesn't exist, it will be created)
         :param task_name: task name to be created (in development mode, not when running remotely)
         :param task_type: task type to be created (in development mode, not when running remotely)
-        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
-            if False every time we call the function we create a new task with the same name \
-            Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
+        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
+            if False every time we call the function we create a new task with the same name
+            Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
+            If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
             Note: A closed or published task will not be reused, and a new task will be created.
         :param output_uri: Default location for output models (currently support folder/S3/GS/ ).
             notice: sub-folders (task_id) is created in the destination folder for all outputs.
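
The string form documented above pins a run to an explicit task. A minimal usage sketch (requires a reachable trains-server; project, task name, and id are illustrative):

    from trains import Task

    # reuse an explicit task id instead of the cached "last" one (id illustrative)
    task = Task.init(project_name='examples', task_name='resumed run',
                     reuse_last_task_id='a1b2c3d4e5f6')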
@ -166,12 +167,31 @@ class Task(_Task):
) )
if cls.__main_task is not None: if cls.__main_task is not None:
# if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread
cls.__main_task._detect_repo_async_thread = None
# remove the logger from the previous process
logger = cls.__main_task.get_logger()
logger.set_flush_period(None)
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task._reporter = None
cls.__main_task.get_logger()
# unregister signal hooks, they cause subprocess to hang
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
if not running_remotely(): if not running_remotely():
verify_defaults_match() verify_defaults_match()
return cls.__main_task return cls.__main_task
# check that we are not a child process, in that case do nothing # check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid(): if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
class _TaskStub(object): class _TaskStub(object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
@ -212,9 +232,10 @@ class Task(_Task):
raise raise
else: else:
Task.__main_task = task Task.__main_task = task
# Patch argparse to be aware of the current task # register the main task for at exit hooks (there should only be one)
argparser_update_currenttask(Task.__main_task) task.__register_at_exit(task._at_exit)
EnvironmentBind.update_current_task(Task.__main_task) # patch OS forking
PatchOsFork.patch_fork()
if auto_connect_frameworks: if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task) PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task) PatchAbsl.update_current_task(Task.__main_task)
@ -227,21 +248,19 @@ class Task(_Task):
if auto_resource_monitoring: if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task) task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start() task._resource_monitor.start()
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# make sure all random generators are initialized with new seed # make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed()) make_deterministic(task.get_random_seed())
if auto_connect_arg_parser: if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task)
# Patch ArgParser to be aware of the current task # Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task) argparser_update_currenttask(Task.__main_task)
# Check if parse args already called. If so, sync task parameters with parser # Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called(): if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args() parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser, parsed_args=parsed_args) task._connect_argparse(parser=parser, parsed_args=parsed_args)
# Make sure we start the logger, it will patch the main logging object and pipe all output # Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger. # if we are running locally and using development mode worker, we will pipe all stdout to logger.
@ -339,7 +358,9 @@ class Task(_Task):
in_dev_mode = not running_remotely() in_dev_mode = not running_remotely()
if in_dev_mode: if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task): if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None default_task_id = None
closed_old_task = cls.__close_timed_out_task(default_task) closed_old_task = cls.__close_timed_out_task(default_task)
else: else:
@ -600,6 +621,9 @@ class Task(_Task):
""" """
self._at_exit() self._at_exit()
self._at_exit_called = False self._at_exit_called = False
# unregister atexit callbacks and signal hooks, if we are the main task
if self.is_main_task():
self.__register_at_exit(None)
def is_current_task(self): def is_current_task(self):
""" """
@@ -914,9 +938,12 @@ class Task(_Task):
        Will happen automatically once we exit code, i.e. atexit
        :return:
        """
+        # protect sub-process at_exit
        if self._at_exit_called:
            return

+        is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
+
        # noinspection PyBroadException
        try:
            # from here do not get into watch dog
@@ -948,28 +975,32 @@ class Task(_Task):
             # from here, do not send log in background thread
             if wait_for_uploads:
                 self.flush(wait_for_uploads=True)
+                # wait until the reporter flush everything
+                self.reporter.stop()
                 if print_done_waiting:
                     self.log.info('Finished uploading')
             else:
                 self._logger._flush_stdout_handler()

-            # from here, do not check worker status
-            if self._dev_worker:
-                self._dev_worker.unregister()
-
-            # change task status
-            if not task_status:
-                pass
-            elif task_status[0] == 'failed':
-                self.mark_failed(status_reason=task_status[1])
-            elif task_status[0] == 'completed':
-                self.completed()
-            elif task_status[0] == 'stopped':
-                self.stopped()
+            if not is_sub_process:
+                # from here, do not check worker status
+                if self._dev_worker:
+                    self._dev_worker.unregister()
+
+                # change task status
+                if not task_status:
+                    pass
+                elif task_status[0] == 'failed':
+                    self.mark_failed(status_reason=task_status[1])
+                elif task_status[0] == 'completed':
+                    self.completed()
+                elif task_status[0] == 'stopped':
+                    self.stopped()

             # stop resource monitoring
             if self._resource_monitor:
                 self._resource_monitor.stop()
             self._logger.set_flush_period(None)
             # this is so in theory we can close a main task and start a new one
             Task.__main_task = None
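
The is_sub_process guard above compares a master-process ID recorded in an environment variable against the current PID, so forked children skip worker unregistration and status changes. A standalone sketch of the pattern; the variable name here is an assumption for illustration, not taken from the trains source:

    import os

    MASTER_PID_ENV = 'TRAINS_PROC_MASTER_ID'  # hypothetical name, illustration only

    def is_sub_process():
        master_pid = os.environ.get(MASTER_PID_ENV)
        # a forked child inherits the parent's recorded pid, which differs from its own
        return bool(master_pid) and master_pid != str(os.getpid())
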
@@ -978,7 +1009,7 @@ class Task(_Task):
             pass

     @classmethod
-    def __register_at_exit(cls, exit_callback):
+    def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
         class ExitHooks(object):
             _orig_exit = None
             _orig_exc_handler = None
@@ -1000,7 +1031,21 @@ class Task(_Task):
                     except Exception:
                         pass
                 self._exit_callback = callback
-                atexit.register(self._exit_callback)
+                if callback:
+                    self.hook()
+                else:
+                    # un register int hook
+                    print('removing int hook', self._orig_exc_handler)
+                    if self._orig_exc_handler:
+                        sys.excepthook = self._orig_exc_handler
+                        self._orig_exc_handler = None
+                    for s in self._org_handlers:
+                        # noinspection PyBroadException
+                        try:
+                            signal.signal(s, self._org_handlers[s])
+                        except Exception:
+                            pass
+                    self._org_handlers = {}

             def hook(self):
                 if self._orig_exit is None:
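
update_callback() now either installs the hooks (callback set) or restores the saved originals (callback is None). The save-and-restore pattern, reduced to a standalone sketch; this is illustrative only, not trains code:

    import signal
    import sys

    _org_handlers = {}
    _orig_excepthook = None

    def install(sig_handler, excepthook):
        global _orig_excepthook
        _orig_excepthook = sys.excepthook
        sys.excepthook = excepthook
        for s in (signal.SIGINT, signal.SIGTERM):
            _org_handlers[s] = signal.getsignal(s)  # remember the original handler
            signal.signal(s, sig_handler)

    def uninstall():
        global _orig_excepthook
        if _orig_excepthook is not None:
            sys.excepthook = _orig_excepthook  # put the original exception hook back
            _orig_excepthook = None
        for s, h in _org_handlers.items():
            signal.signal(s, h)  # restore each saved signal handler
        _org_handlers.clear()
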
@@ -1009,20 +1054,23 @@ class Task(_Task):
                 if self._orig_exc_handler is None:
                     self._orig_exc_handler = sys.excepthook
                     sys.excepthook = self.exc_handler
-                atexit.register(self._exit_callback)
-                if sys.platform == 'win32':
-                    catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
-                                     signal.SIGILL, signal.SIGFPE]
-                else:
-                    catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
-                                     signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
-                for s in catch_signals:
-                    # noinspection PyBroadException
-                    try:
-                        self._org_handlers[s] = signal.getsignal(s)
-                        signal.signal(s, self.signal_handler)
-                    except Exception:
-                        pass
+                if self._exit_callback:
+                    atexit.register(self._exit_callback)
+
+                if self._org_handlers:
+                    if sys.platform == 'win32':
+                        catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
+                                         signal.SIGILL, signal.SIGFPE]
+                    else:
+                        catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
+                                         signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
+                    for s in catch_signals:
+                        # noinspection PyBroadException
+                        try:
+                            self._org_handlers[s] = signal.getsignal(s)
+                            signal.signal(s, self.signal_handler)
+                        except Exception:
+                            pass

             def exit(self, code=0):
                 self.exit_code = code
@@ -1077,6 +1125,22 @@ class Task(_Task):
                 # return handler result
                 return org_handler

+        # we only remove the signals since this will hang subprocesses
+        if only_remove_signal_and_exception_hooks:
+            if not cls.__exit_hook:
+                return
+            if cls.__exit_hook._orig_exc_handler:
+                sys.excepthook = cls.__exit_hook._orig_exc_handler
+                cls.__exit_hook._orig_exc_handler = None
+            for s in cls.__exit_hook._org_handlers:
+                # noinspection PyBroadException
+                try:
+                    signal.signal(s, cls.__exit_hook._org_handlers[s])
+                except Exception:
+                    pass
+            cls.__exit_hook._org_handlers = {}
+            return
+
         if cls.__exit_hook is None:
             # noinspection PyBroadException
             try:
@@ -1084,13 +1148,13 @@ class Task(_Task):
                 cls.__exit_hook.hook()
             except Exception:
                 cls.__exit_hook = None
-        elif cls.__main_task is None:
+        else:
             cls.__exit_hook.update_callback(exit_callback)

     @classmethod
     def __get_task(cls, task_id=None, project_name=None, task_name=None):
         if task_id:
-            return cls(private=cls.__create_protection, task_id=task_id)
+            return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
         res = cls._send(
             cls._get_default_session(),

View File

@@ -1,3 +1,4 @@
+import os
 import time
 from threading import Lock
@@ -6,7 +7,8 @@ import six
 class AsyncManagerMixin(object):
     _async_results_lock = Lock()
-    _async_results = []
+    # per pid (process) list of async jobs (support for sub-processes forking)
+    _async_results = {}

     @classmethod
     def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
@@ -14,8 +16,9 @@ class AsyncManagerMixin(object):
             try:
                 cls._async_results_lock.acquire()
                 # discard completed results
-                cls._async_results = [r for r in cls._async_results if not r.ready()]
-                num_results = len(cls._async_results)
+                pid = os.getpid()
+                cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
+                num_results = len(cls._async_results[pid])
                 if wait_on_max_results is not None and num_results >= wait_on_max_results:
                     # At least max_results results are still pending, wait
                     if wait_cb:
@@ -25,7 +28,9 @@ class AsyncManagerMixin(object):
                     continue
                 # add result
                 if result and not result.ready():
-                    cls._async_results.append(result)
+                    if not cls._async_results.get(pid):
+                        cls._async_results[pid] = []
+                    cls._async_results[pid].append(result)
                 break
             finally:
                 cls._async_results_lock.release()
@@ -34,7 +39,8 @@ class AsyncManagerMixin(object):
     def wait_for_results(cls, timeout=None, max_num_uploads=None):
         remaining = timeout
         count = 0
-        for r in cls._async_results:
+        pid = os.getpid()
+        for r in cls._async_results.get(pid, []):
             if r.ready():
                 continue
             t = time.time()
@@ -48,13 +54,14 @@ class AsyncManagerMixin(object):
             if max_num_uploads is not None and max_num_uploads - count <= 0:
                 break
             if timeout is not None:
-                remaining = max(0, remaining - max(0, time.time() - t))
+                remaining = max(0., remaining - max(0., time.time() - t))
                 if not remaining:
                     break

     @classmethod
     def get_num_results(cls):
-        if cls._async_results is not None:
-            return len([r for r in cls._async_results if not r.ready()])
+        pid = os.getpid()
+        if cls._async_results.get(pid, []):
+            return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
         else:
             return 0
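
Every hunk in this file makes the same change: _async_results becomes a dict keyed by os.getpid(), so results queued before a fork are never awaited or mutated by a child process. The pattern in isolation, as a hedged standalone sketch; r.ready() stands in for an async-result object such as multiprocessing.pool.AsyncResult:

    import os
    from threading import Lock

    class PerPidResults(object):
        _lock = Lock()
        _results = {}  # pid -> list of pending async results

        @classmethod
        def add(cls, result):
            with cls._lock:
                # setdefault keeps a forked child from appending to the parent's bucket
                cls._results.setdefault(os.getpid(), []).append(result)

        @classmethod
        def pending(cls):
            with cls._lock:
                # each process only ever inspects its own bucket
                return [r for r in cls._results.get(os.getpid(), []) if not r.ready()]
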

View File

@@ -1 +1 @@
-__version__ = '0.10.2'
+__version__ = '0.10.3rc1'