From 1d6727d2c0eb6135f6160d789489ae5938ac0857 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 7 Aug 2019 00:05:51 +0300 Subject: [PATCH] Improve auto package detection --- .../backend_interface/task/repo/scriptinfo.py | 28 ++++++++++++++++++- trains/binding/frameworks/pytorch_bind.py | 2 +- trains/task.py | 7 +++-- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 65684c53..6a0da3a9 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -24,6 +24,28 @@ class ScriptRequirements(object): def __init__(self, root_folder): self._root_folder = root_folder + @staticmethod + def get_installed_pkgs_detail(reqs): + """ + HACK: bugfix of the original pigar get_installed_pkgs_detail + + Get mapping for import top level name + and install package name with version. + """ + mapping = dict() + + for path in sys.path: + if os.path.isdir(path) and path.rstrip('/').endswith( + ('site-packages', 'dist-packages')): + new_mapping = reqs._search_path(path) + # BUGFIX: + # override with previous, just like python resolves imports, the first match is the one used. + # unlike the original implementation, where the last one is used. + new_mapping.update(mapping) + mapping = new_mapping + + return mapping + def get_requirements(self): try: from pigar import reqs @@ -31,7 +53,11 @@ class ScriptRequirements(object): from pigar.__main__ import GenerateReqs from pigar.log import logger logger.setLevel(logging.WARNING) - installed_pkgs = reqs.get_installed_pkgs_detail() + try: + # first try our version, if we fail revert to the internal implantation + installed_pkgs = self.get_installed_pkgs_detail(reqs) + except Exception: + installed_pkgs = reqs.get_installed_pkgs_detail() gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs, ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints']) reqs, try_imports, guess = gr.extract_reqs() diff --git a/trains/binding/frameworks/pytorch_bind.py b/trains/binding/frameworks/pytorch_bind.py index 592667cd..a3354f0f 100644 --- a/trains/binding/frameworks/pytorch_bind.py +++ b/trains/binding/frameworks/pytorch_bind.py @@ -3,7 +3,7 @@ import sys import six from pathlib2 import Path -from trains.binding.frameworks.base_bind import PatchBaseModelIO +from ...binding.frameworks.base_bind import PatchBaseModelIO from ..frameworks import _patched_call, WeightsFileHandler, _Empty from ..import_bind import PostImportHookPatching from ...config import running_remotely diff --git a/trains/task.py b/trains/task.py index 98b91eb7..1bc8bd29 100644 --- a/trains/task.py +++ b/trains/task.py @@ -989,13 +989,14 @@ class Task(_Task): # check if we crashed, ot the signal is not interrupt (manual break) task_status = ('stopped', ) if self.__exit_hook: - if self.__exit_hook.exception is not None or \ - (not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)): + if (self.__exit_hook.exception and not isinstance(self.__exit_hook.exception, KeyboardInterrupt)) \ + or (not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)): task_status = ('failed', 'Exception') wait_for_uploads = False else: wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None) - if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None: + if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None and \ + not self.__exit_hook.exception: task_status = ('completed', ) else: task_status = ('stopped', )