From 1bfee56977e5b0a1dcfd89733e130b9c0e87d651 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 8 Nov 2019 22:28:13 +0200 Subject: [PATCH] Improve Windows support --- trains/backend_interface/metrics/events.py | 15 ++++++++++++++- trains/backend_interface/task/args.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py index c7aac3e8..848403f2 100644 --- a/trains/backend_interface/metrics/events.py +++ b/trains/backend_interface/metrics/events.py @@ -87,6 +87,9 @@ class MetricsEventAdapter(object): """ Get information for a file that should be uploaded before this event is sent """ pass + def get_iteration(self): + return self._iter + def update(self, task=None, **kwargs): """ Update event properties """ if task: @@ -175,6 +178,10 @@ class UploadEvent(MetricsEventAdapter): _metric_counters_lock = Lock() _image_file_history_size = int(config.get('metrics.file_history_size', 5)) + @staticmethod + def _replace_slash(part): + return part.replace('\\', '/').strip('/').replace('/', '.slash.') + def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None, image_file_history_size=None, delete_after_upload=False, **kwargs): # param override_filename: override uploaded file name (notice extension will be added from local path @@ -194,6 +201,11 @@ class UploadEvent(MetricsEventAdapter): self._filename = '%s_%s_%08d' % (metric, variant, self._count) else: self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size) + + # make sure we have to '/' in the filename because it might access other folders, + # and we don't want that to occur + self._filename = self._replace_slash(self._filename) + self._upload_uri = upload_uri self._delete_after_upload = delete_after_upload @@ -288,7 +300,8 @@ class UploadEvent(MetricsEventAdapter): filename = self._upload_filename if self._override_storage_key_prefix or not storage_key_prefix: storage_key_prefix = self._override_storage_key_prefix - key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x) + key = '/'.join(x for x in (storage_key_prefix, self._replace_slash(self.metric), + self._replace_slash(self.variant), self._replace_slash(filename)) if x) url = '/'.join(x.strip('/') for x in (e_storage_uri, key)) # make sure we preserve local path root if e_storage_uri.startswith('/'): diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index 8307a2c4..2fdbdca2 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -211,6 +211,22 @@ class _Arguments(object): task_arguments[k] = v except Exception: pass + elif current_action and current_action.type == bool: + # parser.set_defaults cannot cast string `False`/`True` to boolean properly, + # so we have to do it manually here + strip_v = str(v).lower().strip() + if strip_v == 'false' or not strip_v: + v = False + elif strip_v == 'true': + v = True + else: + # else, try to cast to integer + try: + v = int(strip_v) + except ValueError: + pass + task_arguments[k] = v + # add as default try: if current_action and isinstance(current_action, _SubParsersAction):