From 6e7cb6b6f12e1725251b4c8bc00ac665c7841db4 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 1 Nov 2020 02:27:27 +0200 Subject: [PATCH] Fix Hydra support, relative path argv[0]. (issue #219) --- .../backend_interface/task/repo/scriptinfo.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 36d095f7..b802a4bb 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -569,22 +569,24 @@ class ScriptInfo(object): return Path(entry_point).as_posix() @classmethod - def _get_working_dir(cls, repo_root, return_abs=False): - repo_root = Path(repo_root).absolute() - cwd = None - + def _cwd(cls): + # return the current working directory (solve for hydra changing it) # check if running with hydra if sys.modules.get('hydra'): # noinspection PyBroadException try: # noinspection PyPackageRequirements import hydra - cwd = Path(hydra.utils.get_original_cwd()).absolute() + return Path(hydra.utils.get_original_cwd()).absolute() except Exception: pass + return Path.cwd().absolute() - if not cwd: - cwd = Path.cwd().absolute() + @classmethod + def _get_working_dir(cls, repo_root, return_abs=False): + # get the repository working directory (might be different from actual cwd) + repo_root = Path(repo_root).absolute() + cwd = cls._cwd() try: # do not change: test if we are under the repo root folder, it will throw an exception if we are not @@ -594,6 +596,15 @@ class ScriptInfo(object): # Working directory not under repository root, default to repo root return repo_root.as_posix() if return_abs else '.' + @classmethod + def _absolute_path(cls, file_path, cwd): + # return the absolute path, relative to a specific working directory (cwd) + file_path = Path(file_path) + if file_path.is_absolute(): + return file_path.as_posix() + # Convert to absolute and squash 'path/../folder' + return os.path.abspath((Path(cwd).absolute() / file_path).as_posix()) + @classmethod def _get_script_code(cls, script_path): # noinspection PyBroadException @@ -612,7 +623,8 @@ class ScriptInfo(object): if jupyter_filepath: scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()] else: - scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f] + cwd = cls._cwd() + scripts_path = [Path(cls._absolute_path(os.path.normpath(f), cwd)) for f in filepaths if f] if all(not f.is_file() for f in scripts_path): raise ScriptInfoError( "Script file {} could not be found".format(scripts_path)