From 49b578b979be762ef6523ddcccbe73d30d836889 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 30 Oct 2020 10:04:03 +0200 Subject: [PATCH] Add Hydra support phase one: fix current working dir (issue #219). Fix cwd outside of repository root folder --- .../backend_interface/task/repo/scriptinfo.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index d9376038..36d095f7 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -560,7 +560,8 @@ class ScriptInfo(object): try: # Use os.path.relpath as it calculates up dir movements (../) - entry_point = os.path.relpath(str(script_path), str(Path.cwd())) + entry_point = os.path.relpath( + str(script_path), str(cls._get_working_dir(repo_root, return_abs=True))) except ValueError: # Working directory not under repository root entry_point = script_path.relative_to(repo_root) @@ -568,14 +569,30 @@ class ScriptInfo(object): return Path(entry_point).as_posix() @classmethod - def _get_working_dir(cls, repo_root): + def _get_working_dir(cls, repo_root, return_abs=False): repo_root = Path(repo_root).absolute() + cwd = None + + # check if running with hydra + if sys.modules.get('hydra'): + # noinspection PyBroadException + try: + # noinspection PyPackageRequirements + import hydra + cwd = Path(hydra.utils.get_original_cwd()).absolute() + except Exception: + pass + + if not cwd: + cwd = Path.cwd().absolute() try: - return Path.cwd().relative_to(repo_root).as_posix() + # do not change: test if we are under the repo root folder, it will throw an exception if we are not + relative = cwd.relative_to(repo_root).as_posix() + return cwd.as_posix() if return_abs else relative except ValueError: - # Working directory not under repository root - return os.path.curdir + # Working directory not under repository root, default to repo root + return repo_root.as_posix() if return_abs else '.' @classmethod def _get_script_code(cls, script_path):