mirror of
https://github.com/clearml/clearml
synced 2025-04-19 05:44:42 +00:00
Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
94642af4a1
@ -1,7 +1,13 @@
|
||||
# TRAINS SDK configuration file
|
||||
api {
|
||||
# Notice: 'host' is the api server (default port 8008), not the web server.
|
||||
host: http://localhost:8008
|
||||
# web_server on port 8080
|
||||
web_server: "http://localhost:8080"
|
||||
|
||||
# Notice: 'api_server' is the api server (default port 8008), not the web server.
|
||||
api_server: "http://localhost:8008"
|
||||
|
||||
# file server onport 8081
|
||||
files_server: "http://localhost:8081"
|
||||
|
||||
# Credentials are generated in the webapp, http://localhost:8080/admin
|
||||
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
|
||||
|
@ -1,7 +1,8 @@
|
||||
# TRAINS - Example of Matplotlib integration and reporting
|
||||
# TRAINS - Example of Matplotlib and Seaborn integration and reporting
|
||||
#
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from trains import Task
|
||||
|
||||
|
||||
@ -33,4 +34,13 @@ plt.imshow(m)
|
||||
plt.title('Image Title')
|
||||
plt.show()
|
||||
|
||||
print('This is a Matplotlib example')
|
||||
sns.set(style="darkgrid")
|
||||
# Load an example dataset with long-form data
|
||||
fmri = sns.load_dataset("fmri")
|
||||
# Plot the responses for different events and regions
|
||||
sns.lineplot(x="timepoint", y="signal",
|
||||
hue="region", style="event",
|
||||
data=fmri)
|
||||
plt.show()
|
||||
|
||||
print('This is a Matplotlib & Seaborn example')
|
||||
|
@ -1,7 +1,11 @@
|
||||
{
|
||||
version: 1.5
|
||||
# default https://demoapi.trainsai.io host
|
||||
host: ""
|
||||
# default api_server: https://demoapi.trainsai.io
|
||||
api_server: ""
|
||||
# default web_server: https://demoapp.trainsai.io
|
||||
web_server: ""
|
||||
# default files_server: https://demofiles.trainsai.io
|
||||
files_server: ""
|
||||
|
||||
# verify host ssl certificate, set to False only if you have a very good reason
|
||||
verify_certificate: True
|
||||
|
@ -2,6 +2,8 @@ from ...backend_config import EnvEntry
|
||||
|
||||
|
||||
ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
|
||||
ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "ALG_WEB_HOST")
|
||||
ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "ALG_FILES_HOST")
|
||||
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
|
||||
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
|
||||
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
|
||||
|
@ -2,6 +2,7 @@ import json as json_lib
|
||||
import sys
|
||||
import types
|
||||
from socket import gethostname
|
||||
from six.moves.urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
@ -10,11 +11,11 @@ from pyhocon import ConfigTree
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from .callresult import CallResult
|
||||
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
|
||||
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
|
||||
from .request import Request, BatchRequest
|
||||
from .token_manager import TokenManager
|
||||
from ..config import load
|
||||
from ..utils import get_http_session_with_retry
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ..version import __version__
|
||||
|
||||
|
||||
@ -32,11 +33,13 @@ class Session(TokenManager):
|
||||
|
||||
_async_status_code = 202
|
||||
_session_requests = 0
|
||||
_session_initial_timeout = (1.0, 10)
|
||||
_session_timeout = (5.0, None)
|
||||
_session_initial_timeout = (3.0, 10.)
|
||||
_session_timeout = (5.0, 300.)
|
||||
|
||||
api_version = '2.1'
|
||||
default_host = "https://demoapi.trainsai.io"
|
||||
default_web = "https://demoapp.trainsai.io"
|
||||
default_files = "https://demofiles.trainsai.io"
|
||||
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
@ -97,7 +100,7 @@ class Session(TokenManager):
|
||||
self._logger = logger
|
||||
|
||||
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
||||
default=(self.config.get("api.credentials.access_key") or self.default_key)
|
||||
default=(self.config.get("api.credentials.access_key", None) or self.default_key)
|
||||
)
|
||||
if not self.access_key:
|
||||
raise ValueError(
|
||||
@ -105,7 +108,7 @@ class Session(TokenManager):
|
||||
)
|
||||
|
||||
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
|
||||
default=(self.config.get("api.credentials.secret_key") or self.default_secret)
|
||||
default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
|
||||
)
|
||||
if not self.secret_key:
|
||||
raise ValueError(
|
||||
@ -125,7 +128,7 @@ class Session(TokenManager):
|
||||
|
||||
self.__worker = worker or gethostname()
|
||||
|
||||
self.__max_req_size = self.config.get("api.http.max_req_size")
|
||||
self.__max_req_size = self.config.get("api.http.max_req_size", None)
|
||||
if not self.__max_req_size:
|
||||
raise ValueError("missing max request size")
|
||||
|
||||
@ -140,6 +143,11 @@ class Session(TokenManager):
|
||||
except (jwt.DecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# now setup the session reporting, so one consecutive retries will show warning
|
||||
# we do that here, so if we have problems authenticating, we see them immediately
|
||||
# notice: this is across the board warning omission
|
||||
urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)
|
||||
|
||||
def _send_request(
|
||||
self,
|
||||
service,
|
||||
@ -394,7 +402,65 @@ class Session(TokenManager):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
|
||||
return ENV_HOST.get(default=(config.get("api.api_server", None) or
|
||||
config.get("api.host", None) or cls.default_host))
|
||||
|
||||
@classmethod
|
||||
def get_app_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
|
||||
# get from config/environment
|
||||
web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
|
||||
if web_host:
|
||||
return web_host
|
||||
|
||||
# return default
|
||||
host = cls.get_api_server_host(config)
|
||||
if host == cls.default_host:
|
||||
return cls.default_web
|
||||
|
||||
# compose ourselves
|
||||
if '://demoapi.' in host:
|
||||
return host.replace('://demoapi.', '://demoapp.', 1)
|
||||
if '://api.' in host:
|
||||
return host.replace('://api.', '://app.', 1)
|
||||
|
||||
parsed = urlparse(host)
|
||||
if parsed.port == 8008:
|
||||
return host.replace(':8008', ':8080', 1)
|
||||
|
||||
raise ValueError('Could not detect TRAINS web application server')
|
||||
|
||||
@classmethod
|
||||
def get_files_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
# get from config/environment
|
||||
files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
|
||||
if files_host:
|
||||
return files_host
|
||||
|
||||
# return default
|
||||
host = cls.get_api_server_host(config)
|
||||
if host == cls.default_host:
|
||||
return cls.default_files
|
||||
|
||||
# compose ourselves
|
||||
app_host = cls.get_app_server_host(config)
|
||||
parsed = urlparse(app_host)
|
||||
if parsed.port:
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
|
||||
elif parsed.netloc.startswith('demoapp.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
|
||||
elif parsed.netloc.startswith('app.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
|
||||
else:
|
||||
parsed = parsed._replace(netloc=parsed.netloc + ':8081')
|
||||
|
||||
return urlunparse(parsed)
|
||||
|
||||
def _do_refresh_token(self, old_token, exp=None):
|
||||
""" TokenManager abstract method implementation.
|
||||
|
@ -27,6 +27,29 @@ def get_config():
|
||||
return config_obj
|
||||
|
||||
|
||||
def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
|
||||
class RetryFilter(logging.Filter):
|
||||
last_instance = None
|
||||
|
||||
def __init__(self, total, warning_after=5):
|
||||
super(RetryFilter, self).__init__()
|
||||
self.total = total
|
||||
self.display_warning_after = warning_after
|
||||
self.last_instance = self
|
||||
|
||||
def filter(self, record):
|
||||
if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
|
||||
retry_left = self.total - record.args[0].total
|
||||
return retry_left >= self.display_warning_after
|
||||
|
||||
return True
|
||||
|
||||
urllib3_log = logging.getLogger('urllib3.connectionpool')
|
||||
if urllib3_log:
|
||||
urllib3_log.removeFilter(RetryFilter.last_instance)
|
||||
urllib3_log.addFilter(RetryFilter(total_retries, display_warning_after))
|
||||
|
||||
|
||||
class TLSv1HTTPAdapter(HTTPAdapter):
|
||||
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
|
||||
self.poolmanager = PoolManager(num_pools=connections,
|
||||
|
@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
|
||||
if log:
|
||||
log.error(error_msg)
|
||||
|
||||
if res.meta.result_code <= 500:
|
||||
# Proper backend error/bad status code - raise or return
|
||||
if raise_on_errors:
|
||||
raise SendError(res, error_msg)
|
||||
return res
|
||||
|
||||
except requests.exceptions.BaseHTTPError as e:
|
||||
log.error('failed sending %s: %s' % (str(req), str(e)))
|
||||
res = None
|
||||
log.error('Failed sending %s: %s' % (str(req), str(e)))
|
||||
except Exception as e:
|
||||
res = None
|
||||
log.error('Failed sending %s: %s' % (str(req), str(e)))
|
||||
|
||||
# Infrastructure error
|
||||
if log:
|
||||
log.info('retrying request %s' % str(req))
|
||||
if res and res.meta.result_code <= 500:
|
||||
# Proper backend error/bad status code - raise or return
|
||||
if raise_on_errors:
|
||||
raise SendError(res, error_msg)
|
||||
return res
|
||||
|
||||
# # Infrastructure error
|
||||
# if log:
|
||||
# log.info('retrying request %s' % str(req))
|
||||
|
||||
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
|
||||
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
|
||||
|
@ -1,8 +1,8 @@
|
||||
import collections
|
||||
import json
|
||||
|
||||
import cv2
|
||||
import six
|
||||
from threading import Thread, Event
|
||||
|
||||
from ..base import InterfaceBase
|
||||
from ..setupuploadmixin import SetupUploadMixin
|
||||
@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
self._bucket_config = None
|
||||
self._storage_uri = None
|
||||
self._async_enable = async_enable
|
||||
self._flush_frequency = 30.0
|
||||
self._exit_flag = False
|
||||
self._flush_event = Event()
|
||||
self._flush_event.clear()
|
||||
self._thread = Thread(target=self._daemon)
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
|
||||
def _set_storage_uri(self, value):
|
||||
value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
|
||||
@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
def async_enable(self, value):
|
||||
self._async_enable = bool(value)
|
||||
|
||||
def _daemon(self):
|
||||
while not self._exit_flag:
|
||||
self._flush_event.wait(self._flush_frequency)
|
||||
self._flush_event.clear()
|
||||
self._write()
|
||||
# wait for all reports
|
||||
if self.get_num_results() > 0:
|
||||
self.wait_for_results()
|
||||
|
||||
def _report(self, ev):
|
||||
self._events.append(ev)
|
||||
if len(self._events) >= self._flush_threshold:
|
||||
self._write()
|
||||
self.flush()
|
||||
|
||||
def _write(self):
|
||||
if not self._events:
|
||||
@ -88,10 +104,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
"""
|
||||
Flush cached reports to backend.
|
||||
"""
|
||||
self._write()
|
||||
# wait for all reports
|
||||
if self.get_num_results() > 0:
|
||||
self.wait_for_results()
|
||||
self._flush_event.set()
|
||||
|
||||
def stop(self):
|
||||
self._exit_flag = True
|
||||
self._flush_event.set()
|
||||
self._thread.join()
|
||||
|
||||
def report_scalar(self, title, series, value, iter):
|
||||
"""
|
||||
|
@ -77,6 +77,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
|
||||
self._edit_lock = RLock()
|
||||
super(Task, self).__init__(id=task_id, session=session, log=log)
|
||||
self._project_name = None
|
||||
self._storage_uri = None
|
||||
self._input_model = None
|
||||
self._output_model = None
|
||||
@ -87,6 +88,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._parameters_allowed_types = (
|
||||
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
|
||||
)
|
||||
self._app_server = None
|
||||
self._files_server = None
|
||||
|
||||
if not task_id:
|
||||
# generate a new task
|
||||
@ -656,8 +659,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
if self.project is None:
|
||||
return None
|
||||
|
||||
if self._project_name and self._project_name[0] == self.project:
|
||||
return self._project_name[1]
|
||||
|
||||
res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
|
||||
return res.response.project.name
|
||||
self._project_name = (self.project, res.response.project.name)
|
||||
return self._project_name[1]
|
||||
|
||||
def get_tags(self):
|
||||
return self._get_task_property("tags")
|
||||
@ -668,33 +675,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._edit(tags=self.data.tags)
|
||||
|
||||
def _get_default_report_storage_uri(self):
|
||||
app_host = self._get_app_server()
|
||||
parsed = urlparse(app_host)
|
||||
if parsed.port:
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
|
||||
elif parsed.netloc.startswith('demoapp.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
|
||||
elif parsed.netloc.startswith('app.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
|
||||
else:
|
||||
parsed = parsed._replace(netloc=parsed.netloc+':8081')
|
||||
return urlunparse(parsed)
|
||||
if not self._files_server:
|
||||
self._files_server = Session.get_files_server_host()
|
||||
return self._files_server
|
||||
|
||||
@classmethod
|
||||
def _get_api_server(cls):
|
||||
return Session.get_api_server_host()
|
||||
|
||||
@classmethod
|
||||
def _get_app_server(cls):
|
||||
host = cls._get_api_server()
|
||||
if '://demoapi.' in host:
|
||||
return host.replace('://demoapi.', '://demoapp.', 1)
|
||||
if '://api.' in host:
|
||||
return host.replace('://api.', '://app.', 1)
|
||||
|
||||
parsed = urlparse(host)
|
||||
if parsed.port == 8008:
|
||||
return host.replace(':8008', ':8080', 1)
|
||||
def _get_app_server(self):
|
||||
if not self._app_server:
|
||||
self._app_server = Session.get_app_server_host()
|
||||
return self._app_server
|
||||
|
||||
def _edit(self, **kwargs):
|
||||
with self._edit_lock:
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
from ..config import TASK_LOG_ENVIRONMENT, running_remotely
|
||||
|
||||
|
||||
@ -34,3 +36,43 @@ class EnvironmentBind(object):
|
||||
if running_remotely():
|
||||
# put back into os:
|
||||
os.environ.update(env_param)
|
||||
|
||||
|
||||
class PatchOsFork(object):
|
||||
_original_fork = None
|
||||
|
||||
@classmethod
|
||||
def patch_fork(cls):
|
||||
# only once
|
||||
if cls._original_fork:
|
||||
return
|
||||
if six.PY2:
|
||||
cls._original_fork = staticmethod(os.fork)
|
||||
else:
|
||||
cls._original_fork = os.fork
|
||||
os.fork = cls._patched_fork
|
||||
|
||||
@staticmethod
|
||||
def _patched_fork(*args, **kwargs):
|
||||
ret = PatchOsFork._original_fork(*args, **kwargs)
|
||||
# Make sure the new process stdout is logged
|
||||
if not ret:
|
||||
from ..task import Task
|
||||
if Task.current_task() is not None:
|
||||
# bind sub-process logger
|
||||
task = Task.init()
|
||||
task.get_logger().flush()
|
||||
|
||||
# if we got here patch the os._exit of our instance to call us
|
||||
def _at_exit_callback(*args, **kwargs):
|
||||
# call at exit manually
|
||||
# noinspection PyProtectedMember
|
||||
task._at_exit()
|
||||
# noinspection PyProtectedMember
|
||||
return os._org_exit(*args, **kwargs)
|
||||
|
||||
if not hasattr(os, '_org_exit'):
|
||||
os._org_exit = os._exit
|
||||
os._exit = _at_exit_callback
|
||||
|
||||
return ret
|
||||
|
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
@ -44,9 +45,17 @@ class EventTrainsWriter(object):
|
||||
TF SummaryWriter implementation that converts the tensorboard's summary into
|
||||
Trains events and reports the events (metrics) for an Trains task (logger).
|
||||
"""
|
||||
_add_lock = threading.Lock()
|
||||
_add_lock = threading.RLock()
|
||||
_series_name_lookup = {}
|
||||
|
||||
# store all the created tensorboard writers in the system
|
||||
# this allows us to as weather a certain tile/series already exist on some EventWriter
|
||||
# and if it does, then we add to the series name the last token from the logdir
|
||||
# (so we can differentiate between the two)
|
||||
# key, value: key=hash(title, graph), value=EventTrainsWriter._id
|
||||
_title_series_writers_lookup = {}
|
||||
_event_writers_id_to_logdir = {}
|
||||
|
||||
@property
|
||||
def variants(self):
|
||||
return self._variants
|
||||
@ -54,8 +63,8 @@ class EventTrainsWriter(object):
|
||||
def prepare_report(self):
|
||||
return self.variants.copy()
|
||||
|
||||
@staticmethod
|
||||
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
|
||||
def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
|
||||
logdir_header='series'):
|
||||
"""
|
||||
Split a tf.summary tag line to variant and metric.
|
||||
Variant is the first part of the split tag, metric is the second.
|
||||
@ -64,15 +73,64 @@ class EventTrainsWriter(object):
|
||||
:param str split_char: a character to split the tag on
|
||||
:param str join_char: a character to join the the splits
|
||||
:param str default_title: variant to use in case no variant can be inferred automatically
|
||||
:param str logdir_header: if 'series_last' then series=header: series, if 'series then series=series :header,
|
||||
if 'title_last' then title=header title, if 'title' then title=title header
|
||||
:return: (str, str) variant and metric
|
||||
"""
|
||||
splitted_tag = tag.split(split_char)
|
||||
series = join_char.join(splitted_tag[-num_split_parts:])
|
||||
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
|
||||
|
||||
# check if we already decided that we need to change the title/series
|
||||
graph_id = hash((title, series))
|
||||
if graph_id in self._graph_name_lookup:
|
||||
return self._graph_name_lookup[graph_id]
|
||||
|
||||
# check if someone other than us used this combination
|
||||
with self._add_lock:
|
||||
event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
|
||||
if not event_writer_id:
|
||||
# put us there
|
||||
self._title_series_writers_lookup[graph_id] = self._id
|
||||
elif event_writer_id != self._id:
|
||||
# if there is someone else, change our series name and store us
|
||||
org_series = series
|
||||
org_title = title
|
||||
other_logdir = self._event_writers_id_to_logdir[event_writer_id]
|
||||
split_logddir = self._logdir.split(os.path.sep)
|
||||
unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep))
|
||||
header = '/'.join(s for s in split_logddir if s in unique_logdir)
|
||||
if logdir_header == 'series_last':
|
||||
series = header + ': ' + series
|
||||
elif logdir_header == 'series':
|
||||
series = series + ' :' + header
|
||||
elif logdir_header == 'title':
|
||||
title = title + ' ' + header
|
||||
else: # logdir_header == 'title_last':
|
||||
title = header + ' ' + title
|
||||
graph_id = hash((title, series))
|
||||
# check if for some reason the new series is already occupied
|
||||
new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
|
||||
if new_event_writer_id is not None and new_event_writer_id != self._id:
|
||||
# well that's about it, nothing else we could do
|
||||
if logdir_header == 'series_last':
|
||||
series = str(self._logdir) + ': ' + org_series
|
||||
elif logdir_header == 'series':
|
||||
series = org_series + ' :' + str(self._logdir)
|
||||
elif logdir_header == 'title':
|
||||
title = org_title + ' ' + str(self._logdir)
|
||||
else: # logdir_header == 'title_last':
|
||||
title = str(self._logdir) + ' ' + org_title
|
||||
graph_id = hash((title, series))
|
||||
|
||||
self._title_series_writers_lookup[graph_id] = self._id
|
||||
|
||||
# store for next time
|
||||
self._graph_name_lookup[graph_id] = (title, series)
|
||||
return title, series
|
||||
|
||||
def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
|
||||
histogram_granularity=50, max_keep_images=None):
|
||||
def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
|
||||
histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
|
||||
"""
|
||||
Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
|
||||
Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
|
||||
@ -87,6 +145,9 @@ class EventTrainsWriter(object):
|
||||
"""
|
||||
# We are the events_writer, so that's what we'll pass
|
||||
IsTensorboardInit.set_tensorboard_used()
|
||||
self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
|
||||
self._id = hash(self._logdir)
|
||||
self._event_writers_id_to_logdir[self._id] = self._logdir
|
||||
self.max_keep_images = max_keep_images
|
||||
self.report_freq = report_freq
|
||||
self.image_report_freq = image_report_freq if image_report_freq else report_freq
|
||||
@ -99,6 +160,7 @@ class EventTrainsWriter(object):
|
||||
self._hist_report_cache = {}
|
||||
self._hist_x_granularity = 50
|
||||
self._max_step = 0
|
||||
self._graph_name_lookup = {}
|
||||
|
||||
def _decode_image(self, img_str, width, height, color_channels):
|
||||
# noinspection PyBroadException
|
||||
@ -131,7 +193,7 @@ class EventTrainsWriter(object):
|
||||
if img_data_np is None:
|
||||
return
|
||||
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
|
||||
if img_data_np.dtype != np.uint8:
|
||||
# assume scale 0-1
|
||||
img_data_np = (img_data_np * 255).astype(np.uint8)
|
||||
@ -168,7 +230,7 @@ class EventTrainsWriter(object):
|
||||
return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
|
||||
|
||||
def _add_scalar(self, tag, step, scalar_data):
|
||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
|
||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')
|
||||
|
||||
# update scalar cache
|
||||
num, value = self._scalar_report_cache.get((title, series), (0, 0))
|
||||
@ -216,7 +278,8 @@ class EventTrainsWriter(object):
|
||||
# Y-axis (rows) is iteration (from 0 to current Step)
|
||||
# X-axis averaged bins (conformed sample 'bucketLimit')
|
||||
# Z-axis actual value (interpolated 'bucket')
|
||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
|
||||
title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
|
||||
logdir_header='series')
|
||||
|
||||
# get histograms from cache
|
||||
hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
|
||||
@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
|
||||
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
|
||||
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
|
||||
if not self.trains:
|
||||
try:
|
||||
logdir = self.get_logdir()
|
||||
except Exception:
|
||||
logdir = None
|
||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.trains.add_event(*args, **kwargs)
|
||||
@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
|
||||
if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
|
||||
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
|
||||
if not self.trains:
|
||||
try:
|
||||
logdir = self.get_logdir()
|
||||
except Exception:
|
||||
logdir = None
|
||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||
**PatchSummaryToEventTransformer.defaults_dict)
|
||||
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.trains.add_event(*args, **kwargs)
|
||||
@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
|
||||
|
||||
# patch the events writer field, and add a double Event Logger (Trains and original)
|
||||
base_eventwriter = __dict__['event_writer']
|
||||
try:
|
||||
logdir = base_eventwriter.get_logdir()
|
||||
except Exception:
|
||||
logdir = None
|
||||
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
|
||||
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
|
||||
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||
logdir=logdir, **defaults_dict)
|
||||
|
||||
# order is important, the return value of ProxyEventsWriter is the last object in the list
|
||||
__dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
|
||||
@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
|
||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||
|
||||
@staticmethod
|
||||
def _get_event_writer():
|
||||
def _get_event_writer(writer):
|
||||
if not PatchTensorFlowEager.__main_task:
|
||||
return None
|
||||
if PatchTensorFlowEager.__trains_event_writer is None:
|
||||
try:
|
||||
logdir = writer.get_logdir()
|
||||
except Exception:
|
||||
logdir = None
|
||||
PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
|
||||
logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
|
||||
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
|
||||
**PatchTensorFlowEager.defaults_dict)
|
||||
return PatchTensorFlowEager.__trains_event_writer
|
||||
|
||||
@staticmethod
|
||||
@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
|
||||
|
||||
@staticmethod
|
||||
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
|
||||
event_writer = PatchTensorFlowEager._get_event_writer()
|
||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||
if event_writer:
|
||||
try:
|
||||
event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
|
||||
@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
|
||||
|
||||
@staticmethod
|
||||
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
|
||||
event_writer = PatchTensorFlowEager._get_event_writer()
|
||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||
if event_writer:
|
||||
try:
|
||||
event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
|
||||
@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
|
||||
|
||||
@staticmethod
|
||||
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
|
||||
event_writer = PatchTensorFlowEager._get_event_writer()
|
||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||
if event_writer:
|
||||
try:
|
||||
event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
|
||||
@ -1350,93 +1431,3 @@ class PatchTensorflowModelIO(object):
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
class PatchPyTorchModelIO(object):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
PatchPyTorchModelIO.__main_task = task
|
||||
PatchPyTorchModelIO._patch_model_io()
|
||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
if PatchPyTorchModelIO.__patched:
|
||||
return
|
||||
|
||||
if 'torch' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched = True
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# hack: make sure tensorflow.__init__ is called
|
||||
import torch
|
||||
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
||||
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass # print('Failed patching pytorch')
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return ret
|
||||
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
|
||||
singlefile=True, model_name=model_name)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
else:
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
@ -11,19 +11,20 @@ from trains.config import config_obj
|
||||
|
||||
|
||||
description = """
|
||||
Please create new credentials using the web app: {}/admin
|
||||
Please create new credentials using the web app: {}/profile
|
||||
In the Admin page, press "Create new credentials", then press "Copy to clipboard"
|
||||
|
||||
Paste credentials here: """
|
||||
|
||||
try:
|
||||
def_host = ENV_HOST.get(default=config_obj.get("api.host"))
|
||||
def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080'
|
||||
except Exception:
|
||||
def_host = 'http://localhost:8080'
|
||||
|
||||
host_description = """
|
||||
Editing configuration file: {CONFIG_FILE}
|
||||
Enter the url of the trains-server's api service, for example: http://localhost:8008 : """.format(
|
||||
Enter the url of the trains-server's Web service, for example: {HOST}
|
||||
""".format(
|
||||
CONFIG_FILE=LOCAL_CONFIG_FILES[0],
|
||||
HOST=def_host,
|
||||
)
|
||||
@ -37,64 +38,60 @@ def main():
|
||||
print('Leaving setup, feel free to edit the configuration file.')
|
||||
return
|
||||
|
||||
print(host_description, end='')
|
||||
parsed_host = None
|
||||
while not parsed_host:
|
||||
parse_input = input()
|
||||
if not parse_input:
|
||||
parse_input = def_host
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
|
||||
parse_input = 'http://'+parse_input
|
||||
parsed_host = urlparse(parse_input)
|
||||
if parsed_host.scheme not in ('http', 'https'):
|
||||
parsed_host = None
|
||||
except Exception:
|
||||
parsed_host = None
|
||||
print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
|
||||
print(host_description)
|
||||
web_host = input_url('Web Application Host', '')
|
||||
parsed_host = verify_url(web_host)
|
||||
|
||||
if parsed_host.port == 8080:
|
||||
# this is a docker 8080 is the web address, we need the api address, it is 8008
|
||||
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
|
||||
if parsed_host.port == 8008:
|
||||
print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
|
||||
elif parsed_host.port == 8080:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('demoapp.'):
|
||||
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
|
||||
parsed_host.netloc))
|
||||
# this is our demo server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('app.'):
|
||||
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
|
||||
parsed_host.netloc))
|
||||
# this is our application server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
elif parsed_host.port == 8008:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('demoapi.'):
|
||||
print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
|
||||
parsed_host.netloc))
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('api.'):
|
||||
print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
|
||||
parsed_host.netloc))
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
|
||||
else:
|
||||
api_host = None
|
||||
web_host = None
|
||||
api_host = ''
|
||||
web_host = ''
|
||||
files_host = ''
|
||||
if not parsed_host.port:
|
||||
print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
|
||||
replace_port = input().lower()
|
||||
if not replace_port or replace_port == 'y' or replace_port == 'yes':
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
|
||||
files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
|
||||
if not api_host:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
if not web_host:
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
|
||||
print('Host configured to: {}'.format(api_host))
|
||||
api_host = input_url('API Host', api_host)
|
||||
files_host = input_url('File Store Host', files_host)
|
||||
|
||||
print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
|
||||
api_host, web_host, files_host))
|
||||
|
||||
print(description.format(web_host), end='')
|
||||
parse_input = input()
|
||||
@ -133,11 +130,14 @@ def main():
|
||||
header = '# TRAINS SDK configuration file\n' \
|
||||
'api {\n' \
|
||||
' # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
|
||||
' host: %s\n' \
|
||||
' # Credentials are generated in the webapp, %s/admin\n' \
|
||||
' api_server: %s\n' \
|
||||
' web_server: %s\n' \
|
||||
' files_server: %s\n' \
|
||||
' # Credentials are generated in the webapp, %s/profile\n' \
|
||||
' credentials {"access_key": "%s", "secret_key": "%s"}\n' \
|
||||
'}\n' \
|
||||
'sdk ' % (api_host, web_host, credentials['access_key'], credentials['secret_key'])
|
||||
'sdk ' % (api_host, web_host, files_host,
|
||||
web_host, credentials['access_key'], credentials['secret_key'])
|
||||
f.write(header)
|
||||
f.write(default_sdk)
|
||||
except Exception:
|
||||
@ -148,5 +148,30 @@ def main():
|
||||
print('TRAINS setup completed successfully.')
|
||||
|
||||
|
||||
def input_url(host_type, host=None):
|
||||
while True:
|
||||
print('{} configured to: [{}] '.format(host_type, host), end='')
|
||||
parse_input = input()
|
||||
if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
|
||||
break
|
||||
if parse_input and verify_url(parse_input):
|
||||
host = parse_input
|
||||
break
|
||||
return host
|
||||
|
||||
|
||||
def verify_url(parse_input):
|
||||
try:
|
||||
if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
|
||||
parse_input = 'http://' + parse_input
|
||||
parsed_host = urlparse(parse_input)
|
||||
if parsed_host.scheme not in ('http', 'https'):
|
||||
parsed_host = None
|
||||
except Exception:
|
||||
parsed_host = None
|
||||
print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
|
||||
return parsed_host
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -81,11 +81,14 @@ class Logger(object):
|
||||
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
Logger._stdout_original_write = sys.stdout.write
|
||||
if Logger._stdout_original_write is None:
|
||||
Logger._stdout_original_write = sys.stdout.write
|
||||
# this will only work in python 3, guard it with try/catch
|
||||
sys.stdout._original_write = sys.stdout.write
|
||||
if not hasattr(sys.stdout, '_original_write'):
|
||||
sys.stdout._original_write = sys.stdout.write
|
||||
sys.stdout.write = stdout__patched__write__
|
||||
sys.stderr._original_write = sys.stderr.write
|
||||
if not hasattr(sys.stderr, '_original_write'):
|
||||
sys.stderr._original_write = sys.stderr.write
|
||||
sys.stderr.write = stderr__patched__write__
|
||||
except Exception:
|
||||
pass
|
||||
@ -113,6 +116,7 @@ class Logger(object):
|
||||
msg='Logger failed casting log level "%s" to integer' % str(level))
|
||||
level = logging.INFO
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
record = self._task.log.makeRecord(
|
||||
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
|
||||
@ -128,6 +132,7 @@ class Logger(object):
|
||||
if not omit_console:
|
||||
# if we are here and we grabbed the stdout, we need to print the real thing
|
||||
if DevWorker.report_stdout:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we are writing to the original stdout
|
||||
Logger._stdout_original_write(str(msg)+'\n')
|
||||
@ -637,11 +642,13 @@ class Logger(object):
|
||||
@classmethod
|
||||
def _remove_std_logger(self):
|
||||
if isinstance(sys.stdout, PrintPatchLogger):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
sys.stdout.connect(None)
|
||||
except Exception:
|
||||
pass
|
||||
if isinstance(sys.stderr, PrintPatchLogger):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
sys.stderr.connect(None)
|
||||
except Exception:
|
||||
@ -711,7 +718,13 @@ class PrintPatchLogger(object):
|
||||
|
||||
if cur_line:
|
||||
with PrintPatchLogger.recursion_protect_lock:
|
||||
self._log.console(cur_line, level=self._log_level, omit_console=True)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if self._log:
|
||||
self._log.console(cur_line, level=self._log_level, omit_console=True)
|
||||
except Exception:
|
||||
# what can we do, nothing
|
||||
pass
|
||||
else:
|
||||
if hasattr(self._terminal, '_original_write'):
|
||||
self._terminal._original_write(message)
|
||||
@ -719,8 +732,7 @@ class PrintPatchLogger(object):
|
||||
self._terminal.write(message)
|
||||
|
||||
def connect(self, logger):
|
||||
if self._log:
|
||||
self._log._flush_stdout_handler()
|
||||
self._cur_line = ''
|
||||
self._log = logger
|
||||
|
||||
def __getattr__(self, attr):
|
||||
|
154
trains/task.py
154
trains/task.py
@ -26,7 +26,7 @@ from .errors import UsageError
|
||||
from .logger import Logger
|
||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||
from .task_parameters import TaskParameters
|
||||
from .binding.environ_bind import EnvironmentBind
|
||||
from .binding.environ_bind import EnvironmentBind, PatchOsFork
|
||||
from .binding.absl_bind import PatchAbsl
|
||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
argparser_update_currenttask
|
||||
@ -66,6 +66,7 @@ class Task(_Task):
|
||||
__create_protection = object()
|
||||
__main_task = None
|
||||
__exit_hook = None
|
||||
__forked_proc_main_pid = None
|
||||
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
|
||||
__store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
|
||||
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
|
||||
@ -104,7 +105,6 @@ class Task(_Task):
|
||||
self._resource_monitor = None
|
||||
# register atexit, so that we mark the task as stopped
|
||||
self._at_exit_called = False
|
||||
self.__register_at_exit(self._at_exit)
|
||||
|
||||
@classmethod
|
||||
def current_task(cls):
|
||||
@ -132,9 +132,10 @@ class Task(_Task):
|
||||
:param project_name: project to create the task in (if project doesn't exist, it will be created)
|
||||
:param task_name: task name to be created (in development mode, not when running remotely)
|
||||
:param task_type: task type to be created (in development mode, not when running remotely)
|
||||
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
|
||||
if False every time we call the function we create a new task with the same name \
|
||||
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
|
||||
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
|
||||
if False every time we call the function we create a new task with the same name
|
||||
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
|
||||
If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
|
||||
Note: A closed or published task will not be reused, and a new task will be created.
|
||||
:param output_uri: Default location for output models (currently support folder/S3/GS/ ).
|
||||
notice: sub-folders (task_id) is created in the destination folder for all outputs.
|
||||
@ -166,12 +167,31 @@ class Task(_Task):
|
||||
)
|
||||
|
||||
if cls.__main_task is not None:
|
||||
# if this is a subprocess, regardless of what the init was called for,
|
||||
# we have to fix the main task hooks and stdout bindings
|
||||
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
|
||||
# make sure we only do it once per process
|
||||
cls.__forked_proc_main_pid = os.getpid()
|
||||
# make sure we do not wait for the repo detect thread
|
||||
cls.__main_task._detect_repo_async_thread = None
|
||||
# remove the logger from the previous process
|
||||
logger = cls.__main_task.get_logger()
|
||||
logger.set_flush_period(None)
|
||||
# create a new logger (to catch stdout/err)
|
||||
cls.__main_task._logger = None
|
||||
cls.__main_task._reporter = None
|
||||
cls.__main_task.get_logger()
|
||||
# unregister signal hooks, they cause subprocess to hang
|
||||
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
|
||||
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
|
||||
|
||||
if not running_remotely():
|
||||
verify_defaults_match()
|
||||
|
||||
return cls.__main_task
|
||||
|
||||
# check that we are not a child process, in that case do nothing
|
||||
# check that we are not a child process, in that case do nothing.
|
||||
# we should not get here unless this is Windows platform, all others support fork
|
||||
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
|
||||
class _TaskStub(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -212,9 +232,10 @@ class Task(_Task):
|
||||
raise
|
||||
else:
|
||||
Task.__main_task = task
|
||||
# Patch argparse to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
EnvironmentBind.update_current_task(Task.__main_task)
|
||||
# register the main task for at exit hooks (there should only be one)
|
||||
task.__register_at_exit(task._at_exit)
|
||||
# patch OS forking
|
||||
PatchOsFork.patch_fork()
|
||||
if auto_connect_frameworks:
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
@ -227,21 +248,19 @@ class Task(_Task):
|
||||
if auto_resource_monitoring:
|
||||
task._resource_monitor = ResourceMonitor(task)
|
||||
task._resource_monitor.start()
|
||||
# Check if parse args already called. If so, sync task parameters with parser
|
||||
if argparser_parseargs_called():
|
||||
parser, parsed_args = get_argparser_last_args()
|
||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||
|
||||
# make sure all random generators are initialized with new seed
|
||||
make_deterministic(task.get_random_seed())
|
||||
|
||||
if auto_connect_arg_parser:
|
||||
EnvironmentBind.update_current_task(Task.__main_task)
|
||||
|
||||
# Patch ArgParser to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
# Check if parse args already called. If so, sync task parameters with parser
|
||||
if argparser_parseargs_called():
|
||||
parser, parsed_args = get_argparser_last_args()
|
||||
task._connect_argparse(parser, parsed_args=parsed_args)
|
||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||
|
||||
# Make sure we start the logger, it will patch the main logging object and pipe all output
|
||||
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
|
||||
@ -339,7 +358,9 @@ class Task(_Task):
|
||||
in_dev_mode = not running_remotely()
|
||||
|
||||
if in_dev_mode:
|
||||
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
|
||||
default_task_id = reuse_last_task_id
|
||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
default_task_id = None
|
||||
closed_old_task = cls.__close_timed_out_task(default_task)
|
||||
else:
|
||||
@ -600,6 +621,9 @@ class Task(_Task):
|
||||
"""
|
||||
self._at_exit()
|
||||
self._at_exit_called = False
|
||||
# unregister atexit callbacks and signal hooks, if we are the main task
|
||||
if self.is_main_task():
|
||||
self.__register_at_exit(None)
|
||||
|
||||
def is_current_task(self):
|
||||
"""
|
||||
@ -914,9 +938,12 @@ class Task(_Task):
|
||||
Will happen automatically once we exit code, i.e. atexit
|
||||
:return:
|
||||
"""
|
||||
# protect sub-process at_exit
|
||||
if self._at_exit_called:
|
||||
return
|
||||
|
||||
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# from here do not get into watch dog
|
||||
@ -948,28 +975,32 @@ class Task(_Task):
|
||||
# from here, do not send log in background thread
|
||||
if wait_for_uploads:
|
||||
self.flush(wait_for_uploads=True)
|
||||
# wait until the reporter flush everything
|
||||
self.reporter.stop()
|
||||
if print_done_waiting:
|
||||
self.log.info('Finished uploading')
|
||||
else:
|
||||
self._logger._flush_stdout_handler()
|
||||
|
||||
# from here, do not check worker status
|
||||
if self._dev_worker:
|
||||
self._dev_worker.unregister()
|
||||
if not is_sub_process:
|
||||
# from here, do not check worker status
|
||||
if self._dev_worker:
|
||||
self._dev_worker.unregister()
|
||||
|
||||
# change task status
|
||||
if not task_status:
|
||||
pass
|
||||
elif task_status[0] == 'failed':
|
||||
self.mark_failed(status_reason=task_status[1])
|
||||
elif task_status[0] == 'completed':
|
||||
self.completed()
|
||||
elif task_status[0] == 'stopped':
|
||||
self.stopped()
|
||||
# change task status
|
||||
if not task_status:
|
||||
pass
|
||||
elif task_status[0] == 'failed':
|
||||
self.mark_failed(status_reason=task_status[1])
|
||||
elif task_status[0] == 'completed':
|
||||
self.completed()
|
||||
elif task_status[0] == 'stopped':
|
||||
self.stopped()
|
||||
|
||||
# stop resource monitoring
|
||||
if self._resource_monitor:
|
||||
self._resource_monitor.stop()
|
||||
|
||||
self._logger.set_flush_period(None)
|
||||
# this is so in theory we can close a main task and start a new one
|
||||
Task.__main_task = None
|
||||
@ -978,7 +1009,7 @@ class Task(_Task):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __register_at_exit(cls, exit_callback):
|
||||
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
|
||||
class ExitHooks(object):
|
||||
_orig_exit = None
|
||||
_orig_exc_handler = None
|
||||
@ -1000,7 +1031,21 @@ class Task(_Task):
|
||||
except Exception:
|
||||
pass
|
||||
self._exit_callback = callback
|
||||
atexit.register(self._exit_callback)
|
||||
if callback:
|
||||
self.hook()
|
||||
else:
|
||||
# un register int hook
|
||||
print('removing int hook', self._orig_exc_handler)
|
||||
if self._orig_exc_handler:
|
||||
sys.excepthook = self._orig_exc_handler
|
||||
self._orig_exc_handler = None
|
||||
for s in self._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(s, self._org_handlers[s])
|
||||
except Exception:
|
||||
pass
|
||||
self._org_handlers = {}
|
||||
|
||||
def hook(self):
|
||||
if self._orig_exit is None:
|
||||
@ -1009,20 +1054,23 @@ class Task(_Task):
|
||||
if self._orig_exc_handler is None:
|
||||
self._orig_exc_handler = sys.excepthook
|
||||
sys.excepthook = self.exc_handler
|
||||
atexit.register(self._exit_callback)
|
||||
if sys.platform == 'win32':
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE]
|
||||
else:
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||
for s in catch_signals:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._org_handlers[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.signal_handler)
|
||||
except Exception:
|
||||
pass
|
||||
if self._exit_callback:
|
||||
atexit.register(self._exit_callback)
|
||||
|
||||
if self._org_handlers:
|
||||
if sys.platform == 'win32':
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE]
|
||||
else:
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||
for s in catch_signals:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._org_handlers[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.signal_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def exit(self, code=0):
|
||||
self.exit_code = code
|
||||
@ -1077,6 +1125,22 @@ class Task(_Task):
|
||||
# return handler result
|
||||
return org_handler
|
||||
|
||||
# we only remove the signals since this will hang subprocesses
|
||||
if only_remove_signal_and_exception_hooks:
|
||||
if not cls.__exit_hook:
|
||||
return
|
||||
if cls.__exit_hook._orig_exc_handler:
|
||||
sys.excepthook = cls.__exit_hook._orig_exc_handler
|
||||
cls.__exit_hook._orig_exc_handler = None
|
||||
for s in cls.__exit_hook._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(s, cls.__exit_hook._org_handlers[s])
|
||||
except Exception:
|
||||
pass
|
||||
cls.__exit_hook._org_handlers = {}
|
||||
return
|
||||
|
||||
if cls.__exit_hook is None:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -1084,13 +1148,13 @@ class Task(_Task):
|
||||
cls.__exit_hook.hook()
|
||||
except Exception:
|
||||
cls.__exit_hook = None
|
||||
elif cls.__main_task is None:
|
||||
else:
|
||||
cls.__exit_hook.update_callback(exit_callback)
|
||||
|
||||
@classmethod
|
||||
def __get_task(cls, task_id=None, project_name=None, task_name=None):
|
||||
if task_id:
|
||||
return cls(private=cls.__create_protection, task_id=task_id)
|
||||
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
|
||||
|
||||
res = cls._send(
|
||||
cls._get_default_session(),
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
@ -6,7 +7,8 @@ import six
|
||||
|
||||
class AsyncManagerMixin(object):
|
||||
_async_results_lock = Lock()
|
||||
_async_results = []
|
||||
# per pid (process) list of async jobs (support for sub-processes forking)
|
||||
_async_results = {}
|
||||
|
||||
@classmethod
|
||||
def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
|
||||
@ -14,8 +16,9 @@ class AsyncManagerMixin(object):
|
||||
try:
|
||||
cls._async_results_lock.acquire()
|
||||
# discard completed results
|
||||
cls._async_results = [r for r in cls._async_results if not r.ready()]
|
||||
num_results = len(cls._async_results)
|
||||
pid = os.getpid()
|
||||
cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
|
||||
num_results = len(cls._async_results[pid])
|
||||
if wait_on_max_results is not None and num_results >= wait_on_max_results:
|
||||
# At least max_results results are still pending, wait
|
||||
if wait_cb:
|
||||
@ -25,7 +28,9 @@ class AsyncManagerMixin(object):
|
||||
continue
|
||||
# add result
|
||||
if result and not result.ready():
|
||||
cls._async_results.append(result)
|
||||
if not cls._async_results.get(pid):
|
||||
cls._async_results[pid] = []
|
||||
cls._async_results[pid].append(result)
|
||||
break
|
||||
finally:
|
||||
cls._async_results_lock.release()
|
||||
@ -34,7 +39,8 @@ class AsyncManagerMixin(object):
|
||||
def wait_for_results(cls, timeout=None, max_num_uploads=None):
|
||||
remaining = timeout
|
||||
count = 0
|
||||
for r in cls._async_results:
|
||||
pid = os.getpid()
|
||||
for r in cls._async_results.get(pid, []):
|
||||
if r.ready():
|
||||
continue
|
||||
t = time.time()
|
||||
@ -48,13 +54,14 @@ class AsyncManagerMixin(object):
|
||||
if max_num_uploads is not None and max_num_uploads - count <= 0:
|
||||
break
|
||||
if timeout is not None:
|
||||
remaining = max(0, remaining - max(0, time.time() - t))
|
||||
remaining = max(0., remaining - max(0., time.time() - t))
|
||||
if not remaining:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def get_num_results(cls):
|
||||
if cls._async_results is not None:
|
||||
return len([r for r in cls._async_results if not r.ready()])
|
||||
pid = os.getpid()
|
||||
if cls._async_results.get(pid, []):
|
||||
return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
|
||||
else:
|
||||
return 0
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.10.2'
|
||||
__version__ = '0.10.3rc1'
|
||||
|
Loading…
Reference in New Issue
Block a user