Add sdk.development.store_code_diff_from_remote (default False) to store diff from remote HEAD instead of local HEAD (issue #222)

This commit is contained in:
allegroai 2020-10-30 09:55:54 +02:00
parent a0ec4b895b
commit 753b3ff68c
4 changed files with 56 additions and 13 deletions

View File

@ -147,6 +147,7 @@ sdk {
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff: true
store_code_diff_from_remote: false
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: true

View File

@ -51,6 +51,7 @@ class Detector(object):
"""
_fallback = '_fallback'
_remote = '_remote'
@attr.s
class Commands(object):
@ -66,6 +67,10 @@ class Detector(object):
# alternative commands
branch_fallback = attr.ib(default=None, type=list)
diff_fallback = attr.ib(default=None, type=list)
# remote commands
commit_remote = attr.ib(default=None, type=list)
diff_remote = attr.ib(default=None, type=list)
diff_fallback_remote = attr.ib(default=None, type=list)
def __init__(self, type_name, name=None):
self.type_name = type_name
@ -75,14 +80,14 @@ class Detector(object):
""" Returns a RepoInfo instance containing a command for each info attribute """
return self.Commands()
def _get_command_output(self, path, name, command, strip=True):
def _get_command_output(self, path, name, command, commands=None, strip=True):
""" Run a command and return its output """
try:
return get_command_output(command, path, strip=strip)
except (CalledProcessError, UnicodeDecodeError) as ex:
if not name.endswith(self._fallback):
fallback_command = attr.asdict(self._get_commands()).get(name + self._fallback)
fallback_command = attr.asdict(commands or self._get_commands()).get(name + self._fallback)
if fallback_command:
try:
return get_command_output(fallback_command, path, strip=strip)
@ -97,11 +102,12 @@ class Detector(object):
)
return ""
def _get_info(self, path, include_diff=False):
def _get_info(self, path, include_diff=False, diff_from_remote=False):
"""
Get repository information.
:param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available)
:param diff_from_remote: Whether to store the remote diff/commit based on the remote commit (not local commit)
:return: RepoInfo instance
"""
path = str(path)
@ -109,28 +115,56 @@ class Detector(object):
if not include_diff:
commands.diff = None
# skip the local commands
if diff_from_remote and commands:
for name, command in attr.asdict(commands).items():
if name.endswith(self._remote) and command:
setattr(commands, name[:-len(self._remote)], None)
info = Result(
**{
name: self._get_command_output(path, name, command, strip=bool(name != 'diff'))
name: self._get_command_output(path, name, command, commands=commands, strip=bool(name != 'diff'))
for name, command in attr.asdict(commands).items()
if command and not name.endswith(self._fallback)
if command and not name.endswith(self._fallback) and not name.endswith(self._remote)
}
)
if diff_from_remote and commands:
for name, command in attr.asdict(commands).items():
if name.endswith(self._remote) and command:
setattr(commands, name[:-len(self._remote)], command+[info.branch])
info = attr.assoc(
info,
**{
name[:-len(self._remote)]: self._get_command_output(
path, name[:-len(self._remote)], command + [info.branch],
commands=commands, strip=name.startswith('diff'))
for name, command in attr.asdict(commands).items()
if command and (
name.endswith(self._remote) and
not name[:-len(self._remote)].endswith(self._fallback)
)
}
)
# make sure we match the modified with the git remote diff state
info.modified = bool(info.diff)
return info
def _post_process_info(self, info):
# check if there are uncommitted changes in the current repository
return info
def get_info(self, path, include_diff=False):
def get_info(self, path, include_diff=False, diff_from_remote=False):
"""
Get repository information.
:param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available)
:param diff_from_remote: Whether to store the remote diff/commit based on the remote commit (not local commit)
:return: RepoInfo instance
"""
info = self._get_info(path, include_diff)
info = self._get_info(path, include_diff, diff_from_remote=diff_from_remote)
return self._post_process_info(info)
def _is_repo_type(self, script_path):
@ -200,6 +234,9 @@ class GitDetector(Detector):
modified=["git", "ls-files", "-m"],
branch_fallback=["git", "rev-parse", "--abbrev-ref", "HEAD"],
diff_fallback=["git", "diff"],
diff_remote=["git", "diff", "--submodule=diff", ],
commit_remote=["git", "rev-parse", ],
diff_fallback_remote=["git", "diff", ],
)
def _post_process_info(self, info):
@ -234,7 +271,7 @@ class EnvDetector(Detector):
except Exception:
return Path.cwd()
def _get_info(self, _, include_diff=False):
def _get_info(self, _, include_diff=False, diff_from_remote=None):
repository_url = VCS_REPOSITORY_URL.get()
if not repository_url:

View File

@ -589,7 +589,8 @@ class ScriptInfo(object):
return ''
@classmethod
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None,
uncommitted_from_remote=False):
jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath:
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
@ -623,7 +624,8 @@ class ScriptInfo(object):
else:
try:
for i, d in enumerate(scripts_dir):
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
repo_info = plugin.get_info(
str(d), include_diff=check_uncommitted, diff_from_remote=uncommitted_from_remote)
if not repo_info.is_empty():
script_dir = d
script_path = scripts_path[i]
@ -697,13 +699,14 @@ class ScriptInfo(object):
script_requirements)
@classmethod
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None,
uncommitted_from_remote=False):
try:
if not filepaths:
filepaths = [sys.argv[0], ]
return cls._get_script_info(
filepaths=filepaths, check_uncommitted=check_uncommitted,
create_requirements=create_requirements, log=log)
create_requirements=create_requirements, log=log, uncommitted_from_remote=uncommitted_from_remote)
except Exception as ex:
if log:
log.warning("Failed auto-detecting task repository: {}".format(ex))

View File

@ -65,6 +65,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False)
_store_remote_diff = config.get('development.store_code_diff_from_remote', False)
_offline_filename = 'task.json'
class TaskTypes(Enum):
@ -276,7 +277,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
result, script_requirements = ScriptInfo.get(
filepaths=[self._calling_filename, sys.argv[0], ]
if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ],
log=self.log, create_requirements=False, check_uncommitted=self._store_diff
log=self.log, create_requirements=False,
check_uncommitted=self._store_diff, uncommitted_from_remote=self._store_remote_diff
)
for msg in result.warning_messages:
self.get_logger().report_text(msg)