Fix Hydra support, relative path argv[0]. (issue #219)

This commit is contained in:
allegroai 2020-11-01 02:27:27 +02:00
parent 49b578b979
commit 6e7cb6b6f1

View File

@ -569,22 +569,24 @@ class ScriptInfo(object):
return Path(entry_point).as_posix()
@classmethod
def _get_working_dir(cls, repo_root, return_abs=False):
repo_root = Path(repo_root).absolute()
cwd = None
def _cwd(cls):
# return the current working directory (solve for hydra changing it)
# check if running with hydra
if sys.modules.get('hydra'):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
import hydra
cwd = Path(hydra.utils.get_original_cwd()).absolute()
return Path(hydra.utils.get_original_cwd()).absolute()
except Exception:
pass
return Path.cwd().absolute()
if not cwd:
cwd = Path.cwd().absolute()
@classmethod
def _get_working_dir(cls, repo_root, return_abs=False):
# get the repository working directory (might be different from actual cwd)
repo_root = Path(repo_root).absolute()
cwd = cls._cwd()
try:
# do not change: test if we are under the repo root folder, it will throw an exception if we are not
@ -594,6 +596,15 @@ class ScriptInfo(object):
# Working directory not under repository root, default to repo root
return repo_root.as_posix() if return_abs else '.'
@classmethod
def _absolute_path(cls, file_path, cwd):
# return the absolute path, relative to a specific working directory (cwd)
file_path = Path(file_path)
if file_path.is_absolute():
return file_path.as_posix()
# Convert to absolute and squash 'path/../folder'
return os.path.abspath((Path(cwd).absolute() / file_path).as_posix())
@classmethod
def _get_script_code(cls, script_path):
# noinspection PyBroadException
@ -612,7 +623,8 @@ class ScriptInfo(object):
if jupyter_filepath:
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
else:
scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
cwd = cls._cwd()
scripts_path = [Path(cls._absolute_path(os.path.normpath(f), cwd)) for f in filepaths if f]
if all(not f.is_file() for f in scripts_path):
raise ScriptInfoError(
"Script file {} could not be found".format(scripts_path)