mirror of
https://github.com/clearml/clearml
synced 2025-02-07 21:33:25 +00:00
Fix incorrect entry point detection when called from Trains wrapper (such as TrainsLogger Ignite/Lightning)
This commit is contained in:
parent
5fbfa1d6e2
commit
d2d3e595af
@ -673,6 +673,14 @@ class ScriptInfo(object):
|
|||||||
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
log.warning("Failed auto-detecting task repository: {}".format(ex))
|
||||||
return ScriptInfoResult(), None
|
return ScriptInfoResult(), None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_running_from_module(cls):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
return '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def detect_running_module(cls, script_dict):
|
def detect_running_module(cls, script_dict):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -681,7 +689,7 @@ class ScriptInfo(object):
|
|||||||
if script_dict.get('jupyter_filepath'):
|
if script_dict.get('jupyter_filepath'):
|
||||||
return script_dict
|
return script_dict
|
||||||
|
|
||||||
if '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']:
|
if cls.is_running_from_module():
|
||||||
argvs = ''
|
argvs = ''
|
||||||
git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
|
git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
|
||||||
for a in sys.argv[1:]:
|
for a in sys.argv[1:]:
|
||||||
|
@ -257,7 +257,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
check_package_update_thread.start()
|
check_package_update_thread.start()
|
||||||
# do not request requirements, because it might be a long process, and we first want to update the git repo
|
# do not request requirements, because it might be a long process, and we first want to update the git repo
|
||||||
result, script_requirements = ScriptInfo.get(
|
result, script_requirements = ScriptInfo.get(
|
||||||
filepaths=[self._calling_filename, sys.argv[0], ],
|
filepaths=[self._calling_filename, sys.argv[0], ]
|
||||||
|
if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ],
|
||||||
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
|
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
|
||||||
)
|
)
|
||||||
for msg in result.warning_messages:
|
for msg in result.warning_messages:
|
||||||
@ -266,7 +267,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
# store original entry point
|
# store original entry point
|
||||||
entry_point = result.script.get('entry_point') if result.script else None
|
entry_point = result.script.get('entry_point') if result.script else None
|
||||||
|
|
||||||
# check if we are running inside a module, then we should set our entrypoint
|
# check if we are running inside a module, then we should set our entry point
|
||||||
# to the module call including all argv's
|
# to the module call including all argv's
|
||||||
result.script = ScriptInfo.detect_running_module(result.script)
|
result.script = ScriptInfo.detect_running_module(result.script)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user