From d2d3e595af2eb606797f833afe1315c6aecd55e2 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 19 Jun 2020 20:51:46 +0300 Subject: [PATCH] Fix incorrect entry point detection when called from Trains wrapper (such as TrainsLogger Ignite/Lightning) --- trains/backend_interface/task/repo/scriptinfo.py | 10 +++++++++- trains/backend_interface/task/task.py | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 0b2ab265..5f4b4876 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -673,6 +673,14 @@ class ScriptInfo(object): log.warning("Failed auto-detecting task repository: {}".format(ex)) return ScriptInfoResult(), None + @classmethod + def is_running_from_module(cls): + # noinspection PyBroadException + try: + return '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__'] + except Exception: + return False + @classmethod def detect_running_module(cls, script_dict): # noinspection PyBroadException @@ -681,7 +689,7 @@ class ScriptInfo(object): if script_dict.get('jupyter_filepath'): return script_dict - if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']: + if cls.is_running_from_module(): argvs = '' git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None for a in sys.argv[1:]: diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index a8c4d6b0..67ce4f44 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -257,7 +257,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): check_package_update_thread.start() # do not request requirements, because it might be a long process, and we first want to update the git repo result, script_requirements = ScriptInfo.get( - filepaths=[self._calling_filename, sys.argv[0], ], + filepaths=[self._calling_filename, sys.argv[0], ] + if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ], log=self.log, create_requirements=False, check_uncommitted=self._store_diff ) for msg in result.warning_messages: @@ -266,7 +267,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # store original entry point entry_point = result.script.get('entry_point') if result.script else None - # check if we are running inside a module, then we should set our entrypoint + # check if we are running inside a module, then we should set our entry point # to the module call including all argv's result.script = ScriptInfo.detect_running_module(result.script)