Fix python 2.7 and Windows support

This commit is contained in:
allegroai 2019-10-10 21:10:51 +03:00
parent c1bcce9692
commit e0e6d9159b
4 changed files with 81 additions and 51 deletions

View File

@ -2,6 +2,7 @@
import collections import collections
import itertools import itertools
import logging import logging
import os
from enum import Enum from enum import Enum
from threading import Thread from threading import Thread
from multiprocessing import RLock from multiprocessing import RLock
@ -189,9 +190,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True) latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version: if latest_version:
if not latest_version[1]: if not latest_version[1]:
sep = os.linesep
self.get_logger().report_text( self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format( 'TRAINS new package available: UPGRADE to v{} is recommended! '
latest_version[0]), '{}'.format(
latest_version[0], sep.join(latest_version[2])),
) )
else: else:
self.get_logger().report_text( self.get_logger().report_text(

View File

@ -1,7 +1,23 @@
import getpass import getpass
import re import re
from _socket import gethostname from _socket import gethostname
from datetime import datetime, timezone from datetime import datetime
try:
from datetime import timezone
utc_timezone = timezone.utc
except ImportError:
from datetime import tzinfo, timedelta
class UTC(tzinfo):
def utcoffset(self, dt):
return timedelta(0)
def tzname(self, dt):
return "UTC"
def dst(self, dt):
return timedelta(0)
utc_timezone = UTC()
from ..backend_api.services import projects from ..backend_api.services import projects
from ..debugging.log import get_logger from ..debugging.log import get_logger
@ -26,8 +42,8 @@ def get_or_create_project(session, project_name, description=None):
# Hack for supporting windows # Hack for supporting windows
def get_epoch_beginning_of_time(tzinfo=None): def get_epoch_beginning_of_time(timezone_info=None):
return datetime(1970, 1, 1, tzinfo=tzinfo if tzinfo else timezone.utc) return datetime(1970, 1, 1).replace(tzinfo=timezone_info if timezone_info else utc_timezone)
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True): def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):

View File

@ -41,8 +41,6 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict from .utilities.dicts import ReadOnlyDict
NotSet = object()
class Task(_Task): class Task(_Task):
""" """
@ -67,6 +65,8 @@ class Task(_Task):
TaskTypes = _Task.TaskTypes TaskTypes = _Task.TaskTypes
NotSet = object()
__create_protection = object() __create_protection = object()
__main_task = None __main_task = None
__exit_hook = None __exit_hook = None
@ -566,37 +566,15 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def get_logger(self, flush_period=NotSet): def get_logger(self):
# type: (Optional[float]) -> Logger # type: () -> Logger
""" """
get a logger object for reporting based on the task get a logger object for reporting, for this task context.
All reports (metrics, text etc.) related to this task are accessible in the web UI
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
:return: Logger object :return: Logger object
""" """
if not self._logger: return self._get_logger()
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is NotSet:
flush_period = DevWorker.report_period
if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))
if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)
return self._logger
def mark_started(self): def mark_started(self):
""" """
@ -819,6 +797,40 @@ class Task(_Task):
if secret: if secret:
Session.default_secret = secret Session.default_secret = secret
def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
"""
get a logger object for reporting based on the task
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
:return: Logger object
"""
pass
if not self._logger:
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is self.NotSet:
flush_period = DevWorker.report_period
if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))
if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)
return self._logger
def _connect_output_model(self, model): def _connect_output_model(self, model):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self) model.connect(self)

View File

@ -316,24 +316,21 @@ class CheckPackageUpdates(object):
try: try:
from ..version import __version__ from ..version import __version__
cls._package_version_checked = True cls._package_version_checked = True
# Sending the request only for statistics
update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
args=(__version__,))
update_statistics.daemon = True
update_statistics.start()
releases = requests.get('https://pypi.python.org/pypi/trains/json', timeout=3.0).json()['releases'].keys()
releases = [Version(r) for r in releases]
latest_version = sorted(releases)
cur_version = Version(__version__) cur_version = Version(__version__)
if not cur_version.is_devrelease and not cur_version.is_prerelease: update_server_releases = requests.get('https://updates.trainsai.io/updates',
latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease] data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=3.0)
if cur_version >= latest_version[-1]: if update_server_releases.ok:
update_server_releases = update_server_releases.json()
else:
return None return None
not_patch_upgrade = latest_version[-1].release[:2] != cur_version.release[:2] trains_answer = update_server_releases.get("trains", {})
return str(latest_version[-1]), not_patch_upgrade latest_version = Version(trains_answer.get("version"))
if cur_version >= latest_version:
return None
not_patch_upgrade = latest_version.release[:2] == cur_version.release[:2]
return str(latest_version), not_patch_upgrade, trains_answer.get("description").split("\r\n")
except Exception: except Exception:
return None return None
@ -341,6 +338,8 @@ class CheckPackageUpdates(object):
def get_version_from_updates_server(cur_version): def get_version_from_updates_server(cur_version):
try: try:
_ = requests.get('https://updates.trainsai.io/updates', _ = requests.get('https://updates.trainsai.io/updates',
params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0) data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=1.0)
return
except Exception: except Exception:
pass pass