Fix python 2.7 and Windows support

This commit is contained in:
allegroai 2019-10-10 21:10:51 +03:00
parent c1bcce9692
commit e0e6d9159b
4 changed files with 81 additions and 51 deletions

View File

@ -2,6 +2,7 @@
import collections
import itertools
import logging
import os
from enum import Enum
from threading import Thread
from multiprocessing import RLock
@ -189,9 +190,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
sep = os.linesep
self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
'TRAINS new package available: UPGRADE to v{} is recommended! '
'{}'.format(
latest_version[0], sep.join(latest_version[2])),
)
else:
self.get_logger().report_text(

View File

@ -1,7 +1,23 @@
import getpass
import re
from _socket import gethostname
from datetime import datetime, timezone
from datetime import datetime
try:
from datetime import timezone
utc_timezone = timezone.utc
except ImportError:
from datetime import tzinfo, timedelta
class UTC(tzinfo):
def utcoffset(self, dt):
return timedelta(0)
def tzname(self, dt):
return "UTC"
def dst(self, dt):
return timedelta(0)
utc_timezone = UTC()
from ..backend_api.services import projects
from ..debugging.log import get_logger
@ -26,8 +42,8 @@ def get_or_create_project(session, project_name, description=None):
# Hack for supporting windows
def get_epoch_beginning_of_time(tzinfo=None):
return datetime(1970, 1, 1, tzinfo=tzinfo if tzinfo else timezone.utc)
def get_epoch_beginning_of_time(timezone_info=None):
return datetime(1970, 1, 1).replace(tzinfo=timezone_info if timezone_info else utc_timezone)
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):

View File

@ -41,8 +41,6 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict
NotSet = object()
class Task(_Task):
"""
@ -67,6 +65,8 @@ class Task(_Task):
TaskTypes = _Task.TaskTypes
NotSet = object()
__create_protection = object()
__main_task = None
__exit_hook = None
@ -566,37 +566,15 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
def get_logger(self):
# type: () -> Logger
"""
get a logger object for reporting based on the task
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
get a logger object for reporting, for this task context.
All reports (metrics, text etc.) related to this task are accessible in the web UI
:return: Logger object
"""
if not self._logger:
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is NotSet:
flush_period = DevWorker.report_period
if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))
if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)
return self._logger
return self._get_logger()
def mark_started(self):
"""
@ -819,6 +797,40 @@ class Task(_Task):
if secret:
Session.default_secret = secret
def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
"""
get a logger object for reporting based on the task
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
:return: Logger object
"""
pass
if not self._logger:
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is self.NotSet:
flush_period = DevWorker.report_period
if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))
if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)
return self._logger
def _connect_output_model(self, model):
assert isinstance(model, OutputModel)
model.connect(self)

View File

@ -316,24 +316,21 @@ class CheckPackageUpdates(object):
try:
from ..version import __version__
cls._package_version_checked = True
# Sending the request only for statistics
update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
args=(__version__,))
update_statistics.daemon = True
update_statistics.start()
releases = requests.get('https://pypi.python.org/pypi/trains/json', timeout=3.0).json()['releases'].keys()
releases = [Version(r) for r in releases]
latest_version = sorted(releases)
cur_version = Version(__version__)
if not cur_version.is_devrelease and not cur_version.is_prerelease:
latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease]
if cur_version >= latest_version[-1]:
update_server_releases = requests.get('https://updates.trainsai.io/updates',
data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=3.0)
if update_server_releases.ok:
update_server_releases = update_server_releases.json()
else:
return None
not_patch_upgrade = latest_version[-1].release[:2] != cur_version.release[:2]
return str(latest_version[-1]), not_patch_upgrade
trains_answer = update_server_releases.get("trains", {})
latest_version = Version(trains_answer.get("version"))
if cur_version >= latest_version:
return None
not_patch_upgrade = latest_version.release[:2] == cur_version.release[:2]
return str(latest_version), not_patch_upgrade, trains_answer.get("description").split("\r\n")
except Exception:
return None
@ -341,6 +338,8 @@ class CheckPackageUpdates(object):
def get_version_from_updates_server(cur_version):
try:
_ = requests.get('https://updates.trainsai.io/updates',
params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0)
data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=1.0)
return
except Exception:
pass