Mirror of https://github.com/clearml/clearml (synced 2025-05-08 06:44:26 +00:00)

Commit 94642af4a1: Merge remote-tracking branch 'upstream/master'
@@ -1,7 +1,13 @@
 # TRAINS SDK configuration file
 api {
-    # Notice: 'host' is the api server (default port 8008), not the web server.
-    host: http://localhost:8008
+    # web_server on port 8080
+    web_server: "http://localhost:8080"
+
+    # Notice: 'api_server' is the api server (default port 8008), not the web server.
+    api_server: "http://localhost:8008"
+
+    # file server on port 8081
+    files_server: "http://localhost:8081"
 
     # Credentials are generated in the webapp, http://localhost:8080/admin
     credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
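For reference, this is how a HOCON section like the one above can be read programmatically; a minimal sketch assuming the pyhocon package (which the SDK already depends on), with the local-server placeholder values from the example:

    from pyhocon import ConfigFactory

    conf = ConfigFactory.parse_string('''
    api {
        api_server: "http://localhost:8008"
        web_server: "http://localhost:8080"
        files_server: "http://localhost:8081"
    }
    ''')
    # falls back to None when a key is absent, mirroring config.get(key, None) in the SDK code
    print(conf.get('api.api_server', None))  # http://localhost:8008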
@@ -1,7 +1,8 @@
-# TRAINS - Example of Matplotlib integration and reporting
+# TRAINS - Example of Matplotlib and Seaborn integration and reporting
 #
 import numpy as np
 import matplotlib.pyplot as plt
+import seaborn as sns
 from trains import Task
 
 
@@ -33,4 +34,13 @@ plt.imshow(m)
 plt.title('Image Title')
 plt.show()
 
-print('This is a Matplotlib example')
+sns.set(style="darkgrid")
+# Load an example dataset with long-form data
+fmri = sns.load_dataset("fmri")
+# Plot the responses for different events and regions
+sns.lineplot(x="timepoint", y="signal",
+             hue="region", style="event",
+             data=fmri)
+plt.show()
+
+print('This is a Matplotlib & Seaborn example')
@@ -1,7 +1,11 @@
 {
     version: 1.5
-    # default https://demoapi.trainsai.io host
-    host: ""
+    # default api_server: https://demoapi.trainsai.io
+    api_server: ""
+    # default web_server: https://demoapp.trainsai.io
+    web_server: ""
+    # default files_server: https://demofiles.trainsai.io
+    files_server: ""
 
     # verify host ssl certificate, set to False only if you have a very good reason
     verify_certificate: True
@@ -2,6 +2,8 @@ from ...backend_config import EnvEntry
 
 
 ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
+ENV_WEB_HOST = EnvEntry("TRAINS_WEB_HOST", "ALG_WEB_HOST")
+ENV_FILES_HOST = EnvEntry("TRAINS_FILES_HOST", "ALG_FILES_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
@@ -2,6 +2,7 @@ import json as json_lib
 import sys
 import types
 from socket import gethostname
+from six.moves.urllib.parse import urlparse, urlunparse
 
 import jwt
 import requests
@@ -10,11 +11,11 @@ from pyhocon import ConfigTree
 from requests.auth import HTTPBasicAuth
 
 from .callresult import CallResult
-from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
+from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
 from .request import Request, BatchRequest
 from .token_manager import TokenManager
 from ..config import load
-from ..utils import get_http_session_with_retry
+from ..utils import get_http_session_with_retry, urllib_log_warning_setup
 from ..version import __version__
 
 
@@ -32,11 +33,13 @@ class Session(TokenManager):
 
     _async_status_code = 202
     _session_requests = 0
-    _session_initial_timeout = (1.0, 10)
-    _session_timeout = (5.0, None)
+    _session_initial_timeout = (3.0, 10.)
+    _session_timeout = (5.0, 300.)
 
     api_version = '2.1'
     default_host = "https://demoapi.trainsai.io"
+    default_web = "https://demoapp.trainsai.io"
+    default_files = "https://demofiles.trainsai.io"
     default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
     default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
 
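The timeout tuples above follow the requests convention of (connect, read) seconds, so the new (3.0, 10.) allows 3 seconds to establish a connection and 10 to read a response, and (5.0, 300.) bounds long requests at 5 minutes instead of waiting forever (read=None). A one-line illustration (example.com is a placeholder):

    import requests

    # (connect timeout, read timeout) in seconds
    resp = requests.get('http://example.com', timeout=(3.0, 10.0))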
@@ -97,7 +100,7 @@ class Session(TokenManager):
         self._logger = logger
 
         self.__access_key = api_key or ENV_ACCESS_KEY.get(
-            default=(self.config.get("api.credentials.access_key") or self.default_key)
+            default=(self.config.get("api.credentials.access_key", None) or self.default_key)
         )
         if not self.access_key:
             raise ValueError(
@@ -105,7 +108,7 @@ class Session(TokenManager):
         )
 
         self.__secret_key = secret_key or ENV_SECRET_KEY.get(
-            default=(self.config.get("api.credentials.secret_key") or self.default_secret)
+            default=(self.config.get("api.credentials.secret_key", None) or self.default_secret)
         )
         if not self.secret_key:
             raise ValueError(
@@ -125,7 +128,7 @@ class Session(TokenManager):
 
         self.__worker = worker or gethostname()
 
-        self.__max_req_size = self.config.get("api.http.max_req_size")
+        self.__max_req_size = self.config.get("api.http.max_req_size", None)
         if not self.__max_req_size:
             raise ValueError("missing max request size")
 
@@ -140,6 +143,11 @@ class Session(TokenManager):
         except (jwt.DecodeError, ValueError):
             pass
 
+        # now set up the session reporting, so that consecutive retries will show a warning
+        # we do that here, so if we have problems authenticating, we see them immediately
+        # notice: this suppresses retry warnings across the board
+        urllib_log_warning_setup(total_retries=http_retries_config.get('total', 0), display_warning_after=3)
+
     def _send_request(
         self,
         service,
@@ -394,7 +402,65 @@ class Session(TokenManager):
         if not config:
             from ...config import config_obj
             config = config_obj
-        return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
+        return ENV_HOST.get(default=(config.get("api.api_server", None) or
+                                     config.get("api.host", None) or cls.default_host))
+
+    @classmethod
+    def get_app_server_host(cls, config=None):
+        if not config:
+            from ...config import config_obj
+            config = config_obj
+
+        # get from config/environment
+        web_host = ENV_WEB_HOST.get(default=config.get("api.web_server", None))
+        if web_host:
+            return web_host
+
+        # return default
+        host = cls.get_api_server_host(config)
+        if host == cls.default_host:
+            return cls.default_web
+
+        # compose ourselves
+        if '://demoapi.' in host:
+            return host.replace('://demoapi.', '://demoapp.', 1)
+        if '://api.' in host:
+            return host.replace('://api.', '://app.', 1)
+
+        parsed = urlparse(host)
+        if parsed.port == 8008:
+            return host.replace(':8008', ':8080', 1)
+
+        raise ValueError('Could not detect TRAINS web application server')
+
+    @classmethod
+    def get_files_server_host(cls, config=None):
+        if not config:
+            from ...config import config_obj
+            config = config_obj
+        # get from config/environment
+        files_host = ENV_FILES_HOST.get(default=(config.get("api.files_server", None)))
+        if files_host:
+            return files_host
+
+        # return default
+        host = cls.get_api_server_host(config)
+        if host == cls.default_host:
+            return cls.default_files
+
+        # compose ourselves
+        app_host = cls.get_app_server_host(config)
+        parsed = urlparse(app_host)
+        if parsed.port:
+            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
+        elif parsed.netloc.startswith('demoapp.'):
+            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
+        elif parsed.netloc.startswith('app.'):
+            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
+        else:
+            parsed = parsed._replace(netloc=parsed.netloc + ':8081')
+
+        return urlunparse(parsed)
 
     def _do_refresh_token(self, old_token, exp=None):
         """ TokenManager abstract method implementation.
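The two new classmethods derive the web and files hosts from the api host by deployment convention: demoapi./api. prefixes become demoapp./app. (and demofiles./files.), and port 8008 maps to 8080 (web) or 8081 (files). A standalone toy restating just the web mapping (illustrative, not the Session code itself):

    from urllib.parse import urlparse

    def derive_web_host(api_host):
        # conventions used by trains-server deployments
        if '://api.' in api_host:
            return api_host.replace('://api.', '://app.', 1)
        if urlparse(api_host).port == 8008:
            return api_host.replace(':8008', ':8080', 1)
        raise ValueError('cannot derive web host from %s' % api_host)

    print(derive_web_host('http://localhost:8008'))    # http://localhost:8080
    print(derive_web_host('https://api.example.com'))  # https://app.example.com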
@@ -27,6 +27,29 @@ def get_config():
     return config_obj
 
 
+def urllib_log_warning_setup(total_retries=10, display_warning_after=5):
+    class RetryFilter(logging.Filter):
+        last_instance = None
+
+        def __init__(self, total, warning_after=5):
+            super(RetryFilter, self).__init__()
+            self.total = total
+            self.display_warning_after = warning_after
+            self.last_instance = self
+
+        def filter(self, record):
+            if record.args and len(record.args) > 0 and isinstance(record.args[0], Retry):
+                retry_left = self.total - record.args[0].total
+                return retry_left >= self.display_warning_after
+
+            return True
+
+    urllib3_log = logging.getLogger('urllib3.connectionpool')
+    if urllib3_log:
+        urllib3_log.removeFilter(RetryFilter.last_instance)
+        urllib3_log.addFilter(RetryFilter(total_retries, display_warning_after))
+
+
 class TLSv1HTTPAdapter(HTTPAdapter):
     def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
         self.poolmanager = PoolManager(num_pools=connections,
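The filter above inspects urllib3's retry log records and drops them until few enough retries remain. The underlying mechanism is simply a logging.Filter attached to the 'urllib3.connectionpool' logger; a simplified standalone version that mutes the first few records of any kind:

    import logging

    class QuietEarlyRetries(logging.Filter):
        def __init__(self, show_after=3):
            super(QuietEarlyRetries, self).__init__()
            self.show_after = show_after
            self.seen = 0

        def filter(self, record):
            # returning False drops the record; let records through only after `show_after` of them
            self.seen += 1
            return self.seen > self.show_after

    logging.getLogger('urllib3.connectionpool').addFilter(QuietEarlyRetries(show_after=3))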
@@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
                 if log:
                     log.error(error_msg)
 
-            if res.meta.result_code <= 500:
+        except requests.exceptions.BaseHTTPError as e:
+            res = None
+            log.error('Failed sending %s: %s' % (str(req), str(e)))
+        except Exception as e:
+            res = None
+            log.error('Failed sending %s: %s' % (str(req), str(e)))
+
+        if res and res.meta.result_code <= 500:
             # Proper backend error/bad status code - raise or return
             if raise_on_errors:
                 raise SendError(res, error_msg)
             return res
 
-        except requests.exceptions.BaseHTTPError as e:
-            log.error('failed sending %s: %s' % (str(req), str(e)))
-            # Infrastructure error
-            if log:
-                log.info('retrying request %s' % str(req))
+        # # Infrastructure error
+        # if log:
+        #     log.info('retrying request %s' % str(req))
 
     def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
         return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
@@ -1,8 +1,8 @@
 import collections
 import json
 
-import cv2
 import six
+from threading import Thread, Event
 
 from ..base import InterfaceBase
 from ..setupuploadmixin import SetupUploadMixin
@@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         self._bucket_config = None
         self._storage_uri = None
         self._async_enable = async_enable
+        self._flush_frequency = 30.0
+        self._exit_flag = False
+        self._flush_event = Event()
+        self._flush_event.clear()
+        self._thread = Thread(target=self._daemon)
+        self._thread.daemon = True
+        self._thread.start()
 
     def _set_storage_uri(self, value):
         value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
@@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
     def async_enable(self, value):
         self._async_enable = bool(value)
 
+    def _daemon(self):
+        while not self._exit_flag:
+            self._flush_event.wait(self._flush_frequency)
+            self._flush_event.clear()
+            self._write()
+            # wait for all reports
+            if self.get_num_results() > 0:
+                self.wait_for_results()
+
     def _report(self, ev):
         self._events.append(ev)
         if len(self._events) >= self._flush_threshold:
-            self._write()
+            self.flush()
 
     def _write(self):
         if not self._events:
|
|||||||
"""
|
"""
|
||||||
Flush cached reports to backend.
|
Flush cached reports to backend.
|
||||||
"""
|
"""
|
||||||
self._write()
|
self._flush_event.set()
|
||||||
# wait for all reports
|
|
||||||
if self.get_num_results() > 0:
|
def stop(self):
|
||||||
self.wait_for_results()
|
self._exit_flag = True
|
||||||
|
self._flush_event.set()
|
||||||
|
self._thread.join()
|
||||||
|
|
||||||
def report_scalar(self, title, series, value, iter):
|
def report_scalar(self, title, series, value, iter):
|
||||||
"""
|
"""
|
||||||
|
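The Reporter now flushes from a background daemon thread: flush() merely sets an event, and the daemon wakes up either on that event or every _flush_frequency seconds. A minimal standalone sketch of the same pattern (not the Reporter itself):

    import threading

    class PeriodicFlusher(object):
        def __init__(self, flush_frequency=30.0):
            self._flush_frequency = flush_frequency
            self._exit_flag = False
            self._flush_event = threading.Event()
            self._thread = threading.Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

        def _daemon(self):
            while not self._exit_flag:
                # wake on flush() or after the flush interval, whichever comes first
                self._flush_event.wait(self._flush_frequency)
                self._flush_event.clear()
                self._write()

        def _write(self):
            pass  # ship buffered events here

        def flush(self):
            self._flush_event.set()

        def stop(self):
            self._exit_flag = True
            self._flush_event.set()
            self._thread.join()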
@@ -77,6 +77,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
         self._edit_lock = RLock()
         super(Task, self).__init__(id=task_id, session=session, log=log)
+        self._project_name = None
         self._storage_uri = None
         self._input_model = None
         self._output_model = None
@@ -87,6 +88,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         self._parameters_allowed_types = (
             six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
         )
+        self._app_server = None
+        self._files_server = None
 
         if not task_id:
             # generate a new task
@@ -656,8 +659,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if self.project is None:
             return None
 
+        if self._project_name and self._project_name[0] == self.project:
+            return self._project_name[1]
+
         res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
-        return res.response.project.name
+        self._project_name = (self.project, res.response.project.name)
+        return self._project_name[1]
 
     def get_tags(self):
         return self._get_task_property("tags")
@@ -668,33 +675,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         self._edit(tags=self.data.tags)
 
     def _get_default_report_storage_uri(self):
-        app_host = self._get_app_server()
-        parsed = urlparse(app_host)
-        if parsed.port:
-            parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
-        elif parsed.netloc.startswith('demoapp.'):
-            parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
-        elif parsed.netloc.startswith('app.'):
-            parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
-        else:
-            parsed = parsed._replace(netloc=parsed.netloc+':8081')
-        return urlunparse(parsed)
+        if not self._files_server:
+            self._files_server = Session.get_files_server_host()
+        return self._files_server
 
     @classmethod
     def _get_api_server(cls):
         return Session.get_api_server_host()
 
-    @classmethod
-    def _get_app_server(cls):
-        host = cls._get_api_server()
-        if '://demoapi.' in host:
-            return host.replace('://demoapi.', '://demoapp.', 1)
-        if '://api.' in host:
-            return host.replace('://api.', '://app.', 1)
-
-        parsed = urlparse(host)
-        if parsed.port == 8008:
-            return host.replace(':8008', ':8080', 1)
+    def _get_app_server(self):
+        if not self._app_server:
+            self._app_server = Session.get_app_server_host()
+        return self._app_server
 
     def _edit(self, **kwargs):
         with self._edit_lock:
@@ -1,5 +1,7 @@
 import os
 
+import six
+
 from ..config import TASK_LOG_ENVIRONMENT, running_remotely
 
 
@@ -34,3 +36,43 @@ class EnvironmentBind(object):
         if running_remotely():
             # put back into os:
             os.environ.update(env_param)
+
+
+class PatchOsFork(object):
+    _original_fork = None
+
+    @classmethod
+    def patch_fork(cls):
+        # only once
+        if cls._original_fork:
+            return
+        if six.PY2:
+            cls._original_fork = staticmethod(os.fork)
+        else:
+            cls._original_fork = os.fork
+        os.fork = cls._patched_fork
+
+    @staticmethod
+    def _patched_fork(*args, **kwargs):
+        ret = PatchOsFork._original_fork(*args, **kwargs)
+        # Make sure the new process stdout is logged
+        if not ret:
+            from ..task import Task
+            if Task.current_task() is not None:
+                # bind sub-process logger
+                task = Task.init()
+                task.get_logger().flush()
+
+                # if we got here patch the os._exit of our instance to call us
+                def _at_exit_callback(*args, **kwargs):
+                    # call at exit manually
+                    # noinspection PyProtectedMember
+                    task._at_exit()
+                    # noinspection PyProtectedMember
+                    return os._org_exit(*args, **kwargs)
+
+                if not hasattr(os, '_org_exit'):
+                    os._org_exit = os._exit
+                os._exit = _at_exit_callback
+
+        return ret
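PatchOsFork wraps os.fork so the child process re-binds its logger and reports through the task before exiting. Stripped of the trains specifics, the interception pattern looks like this (a POSIX-only sketch; the child-side hook is a placeholder):

    import os

    _original_fork = os.fork

    def _patched_fork(*args, **kwargs):
        pid = _original_fork(*args, **kwargs)
        if pid == 0:
            # we are in the child: re-initialize logging/reporting here
            pass
        return pid

    os.fork = _patched_fork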
@@ -1,4 +1,5 @@
 import base64
+import os
 import sys
 import threading
 from collections import defaultdict
@@ -44,9 +45,17 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     Trains events and reports the events (metrics) for a Trains task (logger).
     """
-    _add_lock = threading.Lock()
+    _add_lock = threading.RLock()
     _series_name_lookup = {}
+
+    # store all the created tensorboard writers in the system
+    # this allows us to ask whether a certain title/series already exists on some EventWriter
+    # and if it does, then we add to the series name the last token from the logdir
+    # (so we can differentiate between the two)
+    # key, value: key=hash(title, graph), value=EventTrainsWriter._id
+    _title_series_writers_lookup = {}
+    _event_writers_id_to_logdir = {}
 
     @property
     def variants(self):
         return self._variants
@@ -54,8 +63,8 @@ class EventTrainsWriter(object):
     def prepare_report(self):
         return self.variants.copy()
 
-    @staticmethod
-    def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
+    def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
+                     logdir_header='series'):
         """
         Split a tf.summary tag line to variant and metric.
         Variant is the first part of the split tag, metric is the second.
@@ -64,15 +73,64 @@ class EventTrainsWriter(object):
         :param str split_char: a character to split the tag on
         :param str join_char: a character to join the splits
         :param str default_title: variant to use in case no variant can be inferred automatically
+        :param str logdir_header: if 'series_last' then series=header: series, if 'series' then series=series :header,
+            if 'title_last' then title=header title, if 'title' then title=title header
         :return: (str, str) variant and metric
         """
         splitted_tag = tag.split(split_char)
         series = join_char.join(splitted_tag[-num_split_parts:])
         title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
+
+        # check if we already decided that we need to change the title/series
+        graph_id = hash((title, series))
+        if graph_id in self._graph_name_lookup:
+            return self._graph_name_lookup[graph_id]
+
+        # check if someone other than us used this combination
+        with self._add_lock:
+            event_writer_id = self._title_series_writers_lookup.get(graph_id, None)
+            if not event_writer_id:
+                # put us there
+                self._title_series_writers_lookup[graph_id] = self._id
+            elif event_writer_id != self._id:
+                # if there is someone else, change our series name and store us
+                org_series = series
+                org_title = title
+                other_logdir = self._event_writers_id_to_logdir[event_writer_id]
+                split_logddir = self._logdir.split(os.path.sep)
+                unique_logdir = set(split_logddir) - set(other_logdir.split(os.path.sep))
+                header = '/'.join(s for s in split_logddir if s in unique_logdir)
+                if logdir_header == 'series_last':
+                    series = header + ': ' + series
+                elif logdir_header == 'series':
+                    series = series + ' :' + header
+                elif logdir_header == 'title':
+                    title = title + ' ' + header
+                else:  # logdir_header == 'title_last':
+                    title = header + ' ' + title
+                graph_id = hash((title, series))
+                # check if for some reason the new series is already occupied
+                new_event_writer_id = self._title_series_writers_lookup.get(graph_id)
+                if new_event_writer_id is not None and new_event_writer_id != self._id:
+                    # well that's about it, nothing else we could do
+                    if logdir_header == 'series_last':
+                        series = str(self._logdir) + ': ' + org_series
+                    elif logdir_header == 'series':
+                        series = org_series + ' :' + str(self._logdir)
+                    elif logdir_header == 'title':
+                        title = org_title + ' ' + str(self._logdir)
+                    else:  # logdir_header == 'title_last':
+                        title = str(self._logdir) + ' ' + org_title
+                    graph_id = hash((title, series))
+
+                self._title_series_writers_lookup[graph_id] = self._id
+
+        # store for next time
+        self._graph_name_lookup[graph_id] = (title, series)
         return title, series
 
-    def __init__(self, logger, report_freq=100, image_report_freq=None, histogram_update_freq_multiplier=10,
-                 histogram_granularity=50, max_keep_images=None):
+    def __init__(self, logger, logdir=None, report_freq=100, image_report_freq=None,
+                 histogram_update_freq_multiplier=10, histogram_granularity=50, max_keep_images=None):
         """
         Create a compatible Trains backend to the TensorFlow SummaryToEventTransformer
         Everything will be serialized directly to the Trains backend, instead of to the standard TF FileWriter
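At its core the split above takes the last num_split_parts path components of a tag as the series and joins the rest into the title; the new lookup logic only kicks in when two tensorboard writers with different logdirs produce the same (title, series) pair. A toy illustration of the base split (values are made up):

    tag = 'train/loss/total'
    parts = tag.split('/')             # ['train', 'loss', 'total']
    series = '_'.join(parts[-1:])      # 'total'       (num_split_parts=1)
    title = '_'.join(parts[:-1]) or 'variant'   # 'train_loss'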
@@ -87,6 +145,9 @@ class EventTrainsWriter(object):
         """
         # We are the events_writer, so that's what we'll pass
         IsTensorboardInit.set_tensorboard_used()
+        self._logdir = logdir or ('unknown %d' % len(self._event_writers_id_to_logdir))
+        self._id = hash(self._logdir)
+        self._event_writers_id_to_logdir[self._id] = self._logdir
         self.max_keep_images = max_keep_images
         self.report_freq = report_freq
         self.image_report_freq = image_report_freq if image_report_freq else report_freq
@@ -99,6 +160,7 @@ class EventTrainsWriter(object):
         self._hist_report_cache = {}
         self._hist_x_granularity = 50
         self._max_step = 0
+        self._graph_name_lookup = {}
 
     def _decode_image(self, img_str, width, height, color_channels):
         # noinspection PyBroadException
@@ -131,7 +193,7 @@ class EventTrainsWriter(object):
         if img_data_np is None:
             return
 
-        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images')
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
         if img_data_np.dtype != np.uint8:
             # assume scale 0-1
             img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -168,7 +230,7 @@ class EventTrainsWriter(object):
         return self._add_image_numpy(tag=tag, step=step, img_data_np=matrix)
 
     def _add_scalar(self, tag, step, scalar_data):
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Scalars', logdir_header='series_last')
 
         # update scalar cache
         num, value = self._scalar_report_cache.get((title, series), (0, 0))
@@ -216,7 +278,8 @@ class EventTrainsWriter(object):
         # Y-axis (rows) is iteration (from 0 to current Step)
         # X-axis averaged bins (conformed sample 'bucketLimit')
         # Z-axis actual value (interpolated 'bucket')
-        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms')
+        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
+                                          logdir_header='series')
 
         # get histograms from cache
         hist_list, hist_iters, minmax = self._hist_report_cache.get((title, series), ([], np.array([]), None))
@@ -570,8 +633,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -584,8 +651,12 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            try:
+                logdir = self.get_logdir()
+            except Exception:
+                logdir = None
             self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
-                                            **PatchSummaryToEventTransformer.defaults_dict)
+                                            logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
             self.trains.add_event(*args, **kwargs)
@@ -617,8 +688,13 @@ class PatchSummaryToEventTransformer(object):
 
         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        try:
+            logdir = base_eventwriter.get_logdir()
+        except Exception:
+            logdir = None
         defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), **defaults_dict)
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+                                         logdir=logdir, **defaults_dict)
 
         # order is important, the return value of ProxyEventsWriter is the last object in the list
         __dict__['event_writer'] = ProxyEventsWriter([trains_event, base_eventwriter])
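ProxyEventsWriter fans every call out to both the Trains writer and the original one, returning the last writer's result (hence the ordering comment above). A minimal sketch of that fan-out idea (illustrative names, not the actual ProxyEventsWriter implementation):

    class FanOutWriter(object):
        def __init__(self, writers):
            self._writers = writers

        def add_event(self, *args, **kwargs):
            ret = None
            for w in self._writers:
                ret = w.add_event(*args, **kwargs)  # the last writer's return value wins
            return ret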
@@ -798,12 +874,17 @@ class PatchTensorFlowEager(object):
             getLogger(TrainsFrameworkAdapter).warning(str(ex))
 
     @staticmethod
-    def _get_event_writer():
+    def _get_event_writer(writer):
         if not PatchTensorFlowEager.__main_task:
             return None
         if PatchTensorFlowEager.__trains_event_writer is None:
+            try:
+                logdir = writer.get_logdir()
+            except Exception:
+                logdir = None
             PatchTensorFlowEager.__trains_event_writer = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), **PatchTensorFlowEager.defaults_dict)
+                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer
 
     @staticmethod
@@ -812,7 +893,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
@@ -822,7 +903,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
@@ -832,7 +913,7 @@ class PatchTensorFlowEager(object):
 
     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
-        event_writer = PatchTensorFlowEager._get_event_writer()
+        event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer:
             try:
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
@@ -1350,93 +1431,3 @@ class PatchTensorflowModelIO(object):
                 pass
         return model
 
-
-class PatchPyTorchModelIO(object):
-    __main_task = None
-    __patched = None
-
-    @staticmethod
-    def update_current_task(task, **kwargs):
-        PatchPyTorchModelIO.__main_task = task
-        PatchPyTorchModelIO._patch_model_io()
-        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-
-    @staticmethod
-    def _patch_model_io():
-        if PatchPyTorchModelIO.__patched:
-            return
-
-        if 'torch' not in sys.modules:
-            return
-
-        PatchPyTorchModelIO.__patched = True
-        # noinspection PyBroadException
-        try:
-            # hack: make sure tensorflow.__init__ is called
-            import torch
-            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
-            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
-        except ImportError:
-            pass
-        except Exception:
-            pass  # print('Failed patching pytorch')
-
-    @staticmethod
-    def _save(original_fn, obj, f, *args, **kwargs):
-        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchPyTorchModelIO.__main_task:
-            return ret
-
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-            # noinspection PyBroadException
-            try:
-                f.flush()
-            except Exception:
-                pass
-        else:
-            filename = None
-
-        # give the model a descriptive name based on the file name
-        # noinspection PyBroadException
-        try:
-            model_name = Path(filename).stem
-        except Exception:
-            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
-                                               singlefile=True, model_name=model_name)
-        return ret
-
-    @staticmethod
-    def _load(original_fn, f, *args, **kwargs):
-        if isinstance(f, six.string_types):
-            filename = f
-        elif hasattr(f, 'name'):
-            filename = f.name
-        else:
-            filename = None
-
-        if not PatchPyTorchModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
-        # register input model
-        empty = _Empty()
-        if running_remotely():
-            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                               PatchPyTorchModelIO.__main_task)
-            model = original_fn(filename or f, *args, **kwargs)
-        else:
-            # try to load model before registering, in case we fail
-            model = original_fn(filename or f, *args, **kwargs)
-            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
-                                                    PatchPyTorchModelIO.__main_task)
-
-        if empty.trains_in_model:
-            # noinspection PyBroadException
-            try:
-                model.trains_in_model = empty.trains_in_model
-            except Exception:
-                pass
-        return model
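The removed class relied on a small wrapper helper to hook torch.save/torch.load, visible above as _patched_call. The same pattern, reduced to its essentials (patched_call and save_hook are illustrative names, not the trains helper itself):

    def patched_call(original_fn, hook):
        def _inner(*args, **kwargs):
            return hook(original_fn, *args, **kwargs)
        return _inner

    def save_hook(original_save, obj, f, *args, **kwargs):
        ret = original_save(obj, f, *args, **kwargs)
        # register the written file as an output model here (cf. _save above)
        return ret

    # torch.save = patched_call(torch.save, save_hook)   # how the patch is wired up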
@@ -11,19 +11,20 @@ from trains.config import config_obj
 
 
 description = """
-Please create new credentials using the web app: {}/admin
+Please create new credentials using the web app: {}/profile
 In the Admin page, press "Create new credentials", then press "Copy to clipboard"
 
 Paste credentials here: """
 
 try:
-    def_host = ENV_HOST.get(default=config_obj.get("api.host"))
+    def_host = ENV_HOST.get(default=config_obj.get("api.web_server")) or 'http://localhost:8080'
 except Exception:
     def_host = 'http://localhost:8080'
 
 host_description = """
 Editing configuration file: {CONFIG_FILE}
-Enter the url of the trains-server's api service, for example: http://localhost:8008 : """.format(
+Enter the url of the trains-server's Web service, for example: {HOST}
+""".format(
     CONFIG_FILE=LOCAL_CONFIG_FILES[0],
     HOST=def_host,
 )
@@ -37,64 +38,60 @@ def main():
         print('Leaving setup, feel free to edit the configuration file.')
         return
 
-    print(host_description, end='')
-    parsed_host = None
-    while not parsed_host:
-        parse_input = input()
-        if not parse_input:
-            parse_input = def_host
-        # noinspection PyBroadException
-        try:
-            if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
-                parse_input = 'http://'+parse_input
-            parsed_host = urlparse(parse_input)
-            if parsed_host.scheme not in ('http', 'https'):
-                parsed_host = None
-        except Exception:
-            parsed_host = None
-            print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
+    print(host_description)
+    web_host = input_url('Web Application Host', '')
+    parsed_host = verify_url(web_host)
 
-    if parsed_host.port == 8080:
-        # this is a docker, 8080 is the web address; we need the api address, it is 8008
-        print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
+    if parsed_host.port == 8008:
+        print('Port 8008 is the api port. Replacing 8080 with 8008 for Web application')
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8081', 1) + parsed_host.path
+    elif parsed_host.port == 8080:
         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
     elif parsed_host.netloc.startswith('demoapp.'):
-        print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
-            parsed_host.netloc))
         # this is our demo server
         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demofiles.', 1) + parsed_host.path
     elif parsed_host.netloc.startswith('app.'):
-        print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
-            parsed_host.netloc))
         # this is our application server
         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-    elif parsed_host.port == 8008:
-        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'files.', 1) + parsed_host.path
     elif parsed_host.netloc.startswith('demoapi.'):
+        print('{} is the api server, we need the web server. Replacing \'demoapi.\' with \'demoapp.\''.format(
+            parsed_host.netloc))
         api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demofiles.', 1) + parsed_host.path
     elif parsed_host.netloc.startswith('api.'):
+        print('{} is the api server, we need the web server. Replacing \'api.\' with \'app.\''.format(
+            parsed_host.netloc))
         api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'files.', 1) + parsed_host.path
     else:
-        api_host = None
-        web_host = None
+        api_host = ''
+        web_host = ''
+        files_host = ''
         if not parsed_host.port:
             print('Host port not detected, do you wish to use the default 8008 port n/[y]? ', end='')
             replace_port = input().lower()
             if not replace_port or replace_port == 'y' or replace_port == 'yes':
                 api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
                 web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8080' + parsed_host.path
+                files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
         if not api_host:
             api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
-        if not web_host:
-            web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
 
-    print('Host configured to: {}'.format(api_host))
+    api_host = input_url('API Host', api_host)
+    files_host = input_url('File Store Host', files_host)
+
+    print('\nTRAINS Hosts configuration:\nAPI: {}\nWeb App: {}\nFile Store: {}\n'.format(
+        api_host, web_host, files_host))
 
     print(description.format(web_host), end='')
     parse_input = input()
@@ -133,11 +130,14 @@ def main():
         header = '# TRAINS SDK configuration file\n' \
                  'api {\n' \
                  '    # Notice: \'host\' is the api server (default port 8008), not the web server.\n' \
-                 '    host: %s\n' \
-                 '    # Credentials are generated in the webapp, %s/admin\n' \
+                 '    api_server: %s\n' \
+                 '    web_server: %s\n' \
+                 '    files_server: %s\n' \
+                 '    # Credentials are generated in the webapp, %s/profile\n' \
                  '    credentials {"access_key": "%s", "secret_key": "%s"}\n' \
                  '}\n' \
-                 'sdk ' % (api_host, web_host, credentials['access_key'], credentials['secret_key'])
+                 'sdk ' % (api_host, web_host, files_host,
+                           web_host, credentials['access_key'], credentials['secret_key'])
         f.write(header)
         f.write(default_sdk)
     except Exception:
@@ -148,5 +148,30 @@ def main():
     print('TRAINS setup completed successfully.')
 
 
+def input_url(host_type, host=None):
+    while True:
+        print('{} configured to: [{}] '.format(host_type, host), end='')
+        parse_input = input()
+        if host and (not parse_input or parse_input.lower() == 'yes' or parse_input.lower() == 'y'):
+            break
+        if parse_input and verify_url(parse_input):
+            host = parse_input
+            break
+    return host
+
+
+def verify_url(parse_input):
+    try:
+        if not parse_input.startswith('http://') and not parse_input.startswith('https://'):
+            parse_input = 'http://' + parse_input
+        parsed_host = urlparse(parse_input)
+        if parsed_host.scheme not in ('http', 'https'):
+            parsed_host = None
+    except Exception:
+        parsed_host = None
+        print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
+    return parsed_host
+
+
 if __name__ == '__main__':
     main()
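A quick sanity check of verify_url above (toy inputs; note that it prepends http:// when no scheme is given):

    print(verify_url('localhost:8080').netloc)            # 'localhost:8080'
    print(verify_url('https://app.example.com').scheme)   # 'https'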
@@ -81,10 +81,13 @@ class Logger(object):
         self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
         # noinspection PyBroadException
         try:
+            if Logger._stdout_original_write is None:
                 Logger._stdout_original_write = sys.stdout.write
             # this will only work in python 3, guard it with try/catch
+            if not hasattr(sys.stdout, '_original_write'):
                 sys.stdout._original_write = sys.stdout.write
             sys.stdout.write = stdout__patched__write__
+            if not hasattr(sys.stderr, '_original_write'):
                 sys.stderr._original_write = sys.stderr.write
             sys.stderr.write = stderr__patched__write__
         except Exception:
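The guarded assignments above replace sys.stdout.write in place, which is why the original write is stashed first. Reduced to the bare pattern (a sketch; assigning to sys.stdout.write can raise on some interpreters, hence the guard):

    import sys

    _original_write = sys.stdout.write

    def _patched_write(text):
        ret = _original_write(text)  # always reach the real terminal first
        # ... hand `text` to the capturing logger here ...
        return ret

    try:
        sys.stdout.write = _patched_write
    except Exception:
        pass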
@@ -113,6 +116,7 @@ class Logger(object):
                 msg='Logger failed casting log level "%s" to integer' % str(level))
             level = logging.INFO
 
+        # noinspection PyBroadException
         try:
             record = self._task.log.makeRecord(
                 "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
@@ -128,6 +132,7 @@ class Logger(object):
         if not omit_console:
             # if we are here and we grabbed the stdout, we need to print the real thing
             if DevWorker.report_stdout:
+                # noinspection PyBroadException
                 try:
                     # make sure we are writing to the original stdout
                     Logger._stdout_original_write(str(msg)+'\n')
@@ -637,11 +642,13 @@ class Logger(object):
     @classmethod
     def _remove_std_logger(self):
         if isinstance(sys.stdout, PrintPatchLogger):
+            # noinspection PyBroadException
             try:
                 sys.stdout.connect(None)
             except Exception:
                 pass
         if isinstance(sys.stderr, PrintPatchLogger):
+            # noinspection PyBroadException
             try:
                 sys.stderr.connect(None)
             except Exception:
@@ -711,7 +718,13 @@ class PrintPatchLogger(object):

             if cur_line:
                 with PrintPatchLogger.recursion_protect_lock:
+                    # noinspection PyBroadException
+                    try:
+                        if self._log:
                             self._log.console(cur_line, level=self._log_level, omit_console=True)
+                    except Exception:
+                        # what can we do, nothing
+                        pass
             else:
                 if hasattr(self._terminal, '_original_write'):
                     self._terminal._original_write(message)
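The try/except added around the console reporting keeps a failing logger from ever breaking the program's own output. For orientation, a toy version of this kind of line-buffering stream proxy (a sketch of the pattern, not the actual PrintPatchLogger):

class LineBufferProxy(object):
    """Pass-through stream wrapper that reports complete lines to a callback."""

    def __init__(self, terminal, callback):
        self._terminal = terminal
        self._cur_line = ''
        self._callback = callback  # receives one complete line at a time

    def write(self, message):
        # always write through to the real stream, but only report whole lines
        self._terminal.write(message)
        self._cur_line += message
        while '\n' in self._cur_line:
            line, self._cur_line = self._cur_line.split('\n', 1)
            try:
                self._callback(line)
            except Exception:
                pass  # reporting must never break the program's own output

    def __getattr__(self, attr):
        # delegate everything else (flush, isatty, ...) to the real stream
        return getattr(self._terminal, attr)

# usage sketch: sys.stdout = LineBufferProxy(sys.stdout, my_report_fn)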
@@ -719,8 +732,7 @@ class PrintPatchLogger(object):
                     self._terminal.write(message)

     def connect(self, logger):
-        if self._log:
-            self._log._flush_stdout_handler()
+        self._cur_line = ''
         self._log = logger

     def __getattr__(self, attr):
trains/task.py (102 changed lines)
@@ -26,7 +26,7 @@ from .errors import UsageError
 from .logger import Logger
 from .model import InputModel, OutputModel, ARCHIVED_TAG
 from .task_parameters import TaskParameters
-from .binding.environ_bind import EnvironmentBind
+from .binding.environ_bind import EnvironmentBind, PatchOsFork
 from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
@@ -66,6 +66,7 @@ class Task(_Task):
     __create_protection = object()
     __main_task = None
     __exit_hook = None
+    __forked_proc_main_pid = None
     __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
     __store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
     __detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@@ -104,7 +105,6 @@ class Task(_Task):
         self._resource_monitor = None
         # register atexit, so that we mark the task as stopped
         self._at_exit_called = False
-        self.__register_at_exit(self._at_exit)

     @classmethod
     def current_task(cls):
@@ -132,9 +132,10 @@ class Task(_Task):
         :param project_name: project to create the task in (if project doesn't exist, it will be created)
         :param task_name: task name to be created (in development mode, not when running remotely)
         :param task_type: task type to be created (in development mode, not when running remotely)
-        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
-            if False every time we call the function we create a new task with the same name \
-            Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
+        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
+            if False every time we call the function we create a new task with the same name
+            Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
+            If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
             Note: A closed or published task will not be reused, and a new task will be created.
         :param output_uri: Default location for output models (currently support folder/S3/GS/ ).
             notice: sub-folders (task_id) is created in the destination folder for all outputs.
@@ -166,12 +167,31 @@ class Task(_Task):
             )

         if cls.__main_task is not None:
+            # if this is a subprocess, regardless of what the init was called for,
+            # we have to fix the main task hooks and stdout bindings
+            if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
+                # make sure we only do it once per process
+                cls.__forked_proc_main_pid = os.getpid()
+                # make sure we do not wait for the repo detect thread
+                cls.__main_task._detect_repo_async_thread = None
+                # remove the logger from the previous process
+                logger = cls.__main_task.get_logger()
+                logger.set_flush_period(None)
+                # create a new logger (to catch stdout/err)
+                cls.__main_task._logger = None
+                cls.__main_task._reporter = None
+                cls.__main_task.get_logger()
+                # unregister signal hooks, they cause subprocess to hang
+                cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
+                cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
+
             if not running_remotely():
                 verify_defaults_match()

             return cls.__main_task

-        # check that we are not a child process, in that case do nothing
+        # check that we are not a child process, in that case do nothing.
+        # we should not get here unless this is Windows platform, all others support fork
         if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
             class _TaskStub(object):
                 def __call__(self, *args, **kwargs):
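The block added above re-initializes the logger and reporter the first time Task.init runs inside a forked child. A stripped-down sketch of the per-process guard it relies on (illustrative names, not the library's internals):

import os

class ForkAware(object):
    _init_pid = None  # pid that last ran the per-process setup

    @classmethod
    def ensure_initialized(cls):
        # a forked child inherits memory but not threads, so anything
        # thread-backed must be rebuilt exactly once in the new process
        if cls._init_pid != os.getpid():
            cls._init_pid = os.getpid()
            cls._rebuild_process_state()

    @classmethod
    def _rebuild_process_state(cls):
        pass  # e.g. recreate background reporters / log handlers here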
@@ -212,9 +232,10 @@ class Task(_Task):
                 raise
         else:
             Task.__main_task = task
-            # Patch argparse to be aware of the current task
-            argparser_update_currenttask(Task.__main_task)
-            EnvironmentBind.update_current_task(Task.__main_task)
+            # register the main task for at exit hooks (there should only be one)
+            task.__register_at_exit(task._at_exit)
+            # patch OS forking
+            PatchOsFork.patch_fork()
             if auto_connect_frameworks:
                 PatchedMatplotlib.update_current_task(Task.__main_task)
                 PatchAbsl.update_current_task(Task.__main_task)
@@ -227,21 +248,19 @@ class Task(_Task):
             if auto_resource_monitoring:
                 task._resource_monitor = ResourceMonitor(task)
                 task._resource_monitor.start()
-            # Check if parse args already called. If so, sync task parameters with parser
-            if argparser_parseargs_called():
-                parser, parsed_args = get_argparser_last_args()
-                task._connect_argparse(parser=parser, parsed_args=parsed_args)

             # make sure all random generators are initialized with new seed
             make_deterministic(task.get_random_seed())

             if auto_connect_arg_parser:
+                EnvironmentBind.update_current_task(Task.__main_task)
+
                 # Patch ArgParser to be aware of the current task
                 argparser_update_currenttask(Task.__main_task)
                 # Check if parse args already called. If so, sync task parameters with parser
                 if argparser_parseargs_called():
                     parser, parsed_args = get_argparser_last_args()
-                    task._connect_argparse(parser, parsed_args=parsed_args)
+                    task._connect_argparse(parser=parser, parsed_args=parsed_args)

         # Make sure we start the logger, it will patch the main logging object and pipe all output
         # if we are running locally and using development mode worker, we will pipe all stdout to logger.
@@ -339,7 +358,9 @@ class Task(_Task):
         in_dev_mode = not running_remotely()

         if in_dev_mode:
-            if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
+            if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
+                default_task_id = reuse_last_task_id
+            elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
                 default_task_id = None
                 closed_old_task = cls.__close_timed_out_task(default_task)
             else:
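With this change, reuse_last_task_id also accepts an explicit task id string, matching the updated docstring above. A usage sketch (the id value is a placeholder, not a real task):

from trains import Task

# pass a specific task id to reuse instead of the cached last-run id
task = Task.init(project_name='examples', task_name='demo',
                 reuse_last_task_id='<existing-task-id>')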
@@ -600,6 +621,9 @@ class Task(_Task):
         """
         self._at_exit()
         self._at_exit_called = False
+        # unregister atexit callbacks and signal hooks, if we are the main task
+        if self.is_main_task():
+            self.__register_at_exit(None)

     def is_current_task(self):
         """
@@ -914,9 +938,12 @@ class Task(_Task):
        Will happen automatically once we exit code, i.e. atexit
        :return:
        """
+       # protect sub-process at_exit
        if self._at_exit_called:
            return

+       is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
+
        # noinspection PyBroadException
        try:
            # from here do not get into watch dog
@@ -948,11 +975,14 @@ class Task(_Task):
                 # from here, do not send log in background thread
                 if wait_for_uploads:
                     self.flush(wait_for_uploads=True)
+                    # wait until the reporter flush everything
+                    self.reporter.stop()
                     if print_done_waiting:
                         self.log.info('Finished uploading')
                 else:
                     self._logger._flush_stdout_handler()

+                if not is_sub_process:
                     # from here, do not check worker status
                     if self._dev_worker:
                         self._dev_worker.unregister()
@@ -970,6 +1000,7 @@ class Task(_Task):
                 # stop resource monitoring
                 if self._resource_monitor:
                     self._resource_monitor.stop()
+
                 self._logger.set_flush_period(None)
                 # this is so in theory we can close a main task and start a new one
                 Task.__main_task = None
@@ -978,7 +1009,7 @@ class Task(_Task):
             pass

     @classmethod
-    def __register_at_exit(cls, exit_callback):
+    def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
         class ExitHooks(object):
             _orig_exit = None
             _orig_exc_handler = None
@@ -1000,7 +1031,21 @@ class Task(_Task):
                 except Exception:
                     pass
                 self._exit_callback = callback
-                atexit.register(self._exit_callback)
+                if callback:
+                    self.hook()
+                else:
+                    # un register int hook
+                    print('removing int hook', self._orig_exc_handler)
+                    if self._orig_exc_handler:
+                        sys.excepthook = self._orig_exc_handler
+                        self._orig_exc_handler = None
+                    for s in self._org_handlers:
+                        # noinspection PyBroadException
+                        try:
+                            signal.signal(s, self._org_handlers[s])
+                        except Exception:
+                            pass
+                    self._org_handlers = {}

             def hook(self):
                 if self._orig_exit is None:
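The new else branch restores whatever handlers were saved when the hooks were installed. A self-contained sketch of that save/restore pattern (illustrative, not the ExitHooks class itself):

import signal

_saved_handlers = {}

def install_handler(handler, signals=(signal.SIGINT, signal.SIGTERM)):
    # remember each previous handler so it can be restored later
    for s in signals:
        _saved_handlers[s] = signal.signal(s, handler)

def restore_handlers():
    # put the original handlers back and forget them, as the hunk above does
    for s, original in _saved_handlers.items():
        try:
            signal.signal(s, original)
        except Exception:
            pass
    _saved_handlers.clear()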
@@ -1009,7 +1054,10 @@ class Task(_Task):
                 if self._orig_exc_handler is None:
                     self._orig_exc_handler = sys.excepthook
                     sys.excepthook = self.exc_handler
+                if self._exit_callback:
                     atexit.register(self._exit_callback)
+
+                if self._org_handlers:
                 if sys.platform == 'win32':
                     catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
                                      signal.SIGILL, signal.SIGFPE]
@@ -1077,6 +1125,22 @@ class Task(_Task):
                 # return handler result
                 return org_handler

+        # we only remove the signals since this will hang subprocesses
+        if only_remove_signal_and_exception_hooks:
+            if not cls.__exit_hook:
+                return
+            if cls.__exit_hook._orig_exc_handler:
+                sys.excepthook = cls.__exit_hook._orig_exc_handler
+                cls.__exit_hook._orig_exc_handler = None
+            for s in cls.__exit_hook._org_handlers:
+                # noinspection PyBroadException
+                try:
+                    signal.signal(s, cls.__exit_hook._org_handlers[s])
+                except Exception:
+                    pass
+            cls.__exit_hook._org_handlers = {}
+            return
+
         if cls.__exit_hook is None:
             # noinspection PyBroadException
             try:
@@ -1084,13 +1148,13 @@ class Task(_Task):
                 cls.__exit_hook.hook()
             except Exception:
                 cls.__exit_hook = None
-        elif cls.__main_task is None:
+        else:
             cls.__exit_hook.update_callback(exit_callback)

     @classmethod
     def __get_task(cls, task_id=None, project_name=None, task_name=None):
         if task_id:
-            return cls(private=cls.__create_protection, task_id=task_id)
+            return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)

         res = cls._send(
             cls._get_default_session(),
@@ -1,3 +1,4 @@
+import os
 import time
 from threading import Lock

@@ -6,7 +7,8 @@ import six

 class AsyncManagerMixin(object):
     _async_results_lock = Lock()
-    _async_results = []
+    # per pid (process) list of async jobs (support for sub-processes forking)
+    _async_results = {}

     @classmethod
     def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
@@ -14,8 +16,9 @@ class AsyncManagerMixin(object):
         try:
             cls._async_results_lock.acquire()
             # discard completed results
-            cls._async_results = [r for r in cls._async_results if not r.ready()]
-            num_results = len(cls._async_results)
+            pid = os.getpid()
+            cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
+            num_results = len(cls._async_results[pid])
             if wait_on_max_results is not None and num_results >= wait_on_max_results:
                 # At least max_results results are still pending, wait
                 if wait_cb:
@@ -25,7 +28,9 @@ class AsyncManagerMixin(object):
                     continue
             # add result
             if result and not result.ready():
-                cls._async_results.append(result)
+                if not cls._async_results.get(pid):
+                    cls._async_results[pid] = []
+                cls._async_results[pid].append(result)
             break
         finally:
             cls._async_results_lock.release()
@@ -34,7 +39,8 @@ class AsyncManagerMixin(object):
     def wait_for_results(cls, timeout=None, max_num_uploads=None):
         remaining = timeout
         count = 0
-        for r in cls._async_results:
+        pid = os.getpid()
+        for r in cls._async_results.get(pid, []):
             if r.ready():
                 continue
             t = time.time()
@@ -48,13 +54,14 @@ class AsyncManagerMixin(object):
             if max_num_uploads is not None and max_num_uploads - count <= 0:
                 break
             if timeout is not None:
-                remaining = max(0, remaining - max(0, time.time() - t))
+                remaining = max(0., remaining - max(0., time.time() - t))
                 if not remaining:
                     break

     @classmethod
     def get_num_results(cls):
-        if cls._async_results is not None:
-            return len([r for r in cls._async_results if not r.ready()])
+        pid = os.getpid()
+        if cls._async_results.get(pid, []):
+            return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
         else:
             return 0
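Keying the shared results list by os.getpid() keeps each process (parent and forked children) from touching the others' pending uploads. A condensed sketch of the pattern (illustrative class, not the mixin itself; results are assumed to expose ready() like multiprocessing's AsyncResult):

import os
from threading import Lock

class PerPidResults(object):
    _lock = Lock()
    _results = {}  # pid -> pending async results for that process

    @classmethod
    def add(cls, result):
        with cls._lock:
            pid = os.getpid()
            # prune finished entries, then track the new one under our pid
            pending = [r for r in cls._results.get(pid, []) if not r.ready()]
            pending.append(result)
            cls._results[pid] = pending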
@@ -1 +1 @@
-__version__ = '0.10.2'
+__version__ = '0.10.3rc1'