Fix incorrect entry point detection when called from Trains wrapper (such as TrainsLogger Ignite/Lightning)

This commit is contained in:
allegroai 2020-06-19 20:51:46 +03:00
parent 5fbfa1d6e2
commit d2d3e595af
2 changed files with 12 additions and 3 deletions

View File

@ -673,6 +673,14 @@ class ScriptInfo(object):
log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None
@classmethod
def is_running_from_module(cls):
# noinspection PyBroadException
try:
return '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']
except Exception:
return False
@classmethod
def detect_running_module(cls, script_dict):
# noinspection PyBroadException
@ -681,7 +689,7 @@ class ScriptInfo(object):
if script_dict.get('jupyter_filepath'):
return script_dict
if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']:
if cls.is_running_from_module():
argvs = ''
git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
for a in sys.argv[1:]:

View File

@ -257,7 +257,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
check_package_update_thread.start()
# do not request requirements, because it might be a long process, and we first want to update the git repo
result, script_requirements = ScriptInfo.get(
filepaths=[self._calling_filename, sys.argv[0], ],
filepaths=[self._calling_filename, sys.argv[0], ]
if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ],
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
)
for msg in result.warning_messages:
@ -266,7 +267,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# store original entry point
entry_point = result.script.get('entry_point') if result.script else None
# check if we are running inside a module, then we should set our entrypoint
# check if we are running inside a module, then we should set our entry point
# to the module call including all argv's
result.script = ScriptInfo.detect_running_module(result.script)