Add sdk.development.store_code_diff_from_remote (default False) to store diff from remote HEAD instead of local HEAD (issue #222)

This commit is contained in:
allegroai 2020-10-30 09:55:54 +02:00
parent a0ec4b895b
commit 753b3ff68c
4 changed files with 56 additions and 13 deletions

View File

@ -147,6 +147,7 @@ sdk {
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff: true store_uncommitted_code_diff: true
store_code_diff_from_remote: false
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: true support_stopping: true

View File

@ -51,6 +51,7 @@ class Detector(object):
""" """
_fallback = '_fallback' _fallback = '_fallback'
_remote = '_remote'
@attr.s @attr.s
class Commands(object): class Commands(object):
@ -66,6 +67,10 @@ class Detector(object):
# alternative commands # alternative commands
branch_fallback = attr.ib(default=None, type=list) branch_fallback = attr.ib(default=None, type=list)
diff_fallback = attr.ib(default=None, type=list) diff_fallback = attr.ib(default=None, type=list)
# remote commands
commit_remote = attr.ib(default=None, type=list)
diff_remote = attr.ib(default=None, type=list)
diff_fallback_remote = attr.ib(default=None, type=list)
def __init__(self, type_name, name=None): def __init__(self, type_name, name=None):
self.type_name = type_name self.type_name = type_name
@ -75,14 +80,14 @@ class Detector(object):
""" Returns a RepoInfo instance containing a command for each info attribute """ """ Returns a RepoInfo instance containing a command for each info attribute """
return self.Commands() return self.Commands()
def _get_command_output(self, path, name, command, strip=True): def _get_command_output(self, path, name, command, commands=None, strip=True):
""" Run a command and return its output """ """ Run a command and return its output """
try: try:
return get_command_output(command, path, strip=strip) return get_command_output(command, path, strip=strip)
except (CalledProcessError, UnicodeDecodeError) as ex: except (CalledProcessError, UnicodeDecodeError) as ex:
if not name.endswith(self._fallback): if not name.endswith(self._fallback):
fallback_command = attr.asdict(self._get_commands()).get(name + self._fallback) fallback_command = attr.asdict(commands or self._get_commands()).get(name + self._fallback)
if fallback_command: if fallback_command:
try: try:
return get_command_output(fallback_command, path, strip=strip) return get_command_output(fallback_command, path, strip=strip)
@ -97,11 +102,12 @@ class Detector(object):
) )
return "" return ""
def _get_info(self, path, include_diff=False): def _get_info(self, path, include_diff=False, diff_from_remote=False):
""" """
Get repository information. Get repository information.
:param path: Path to repository :param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available) :param include_diff: Whether to include the diff command's output (if available)
:param diff_from_remote: Whether to store the remote diff/commit based on the remote commit (not local commit)
:return: RepoInfo instance :return: RepoInfo instance
""" """
path = str(path) path = str(path)
@ -109,28 +115,56 @@ class Detector(object):
if not include_diff: if not include_diff:
commands.diff = None commands.diff = None
# skip the local commands
if diff_from_remote and commands:
for name, command in attr.asdict(commands).items():
if name.endswith(self._remote) and command:
setattr(commands, name[:-len(self._remote)], None)
info = Result( info = Result(
**{ **{
name: self._get_command_output(path, name, command, strip=bool(name != 'diff')) name: self._get_command_output(path, name, command, commands=commands, strip=bool(name != 'diff'))
for name, command in attr.asdict(commands).items() for name, command in attr.asdict(commands).items()
if command and not name.endswith(self._fallback) if command and not name.endswith(self._fallback) and not name.endswith(self._remote)
} }
) )
if diff_from_remote and commands:
for name, command in attr.asdict(commands).items():
if name.endswith(self._remote) and command:
setattr(commands, name[:-len(self._remote)], command+[info.branch])
info = attr.assoc(
info,
**{
name[:-len(self._remote)]: self._get_command_output(
path, name[:-len(self._remote)], command + [info.branch],
commands=commands, strip=name.startswith('diff'))
for name, command in attr.asdict(commands).items()
if command and (
name.endswith(self._remote) and
not name[:-len(self._remote)].endswith(self._fallback)
)
}
)
# make sure we match the modified with the git remote diff state
info.modified = bool(info.diff)
return info return info
def _post_process_info(self, info): def _post_process_info(self, info):
# check if there are uncommitted changes in the current repository # check if there are uncommitted changes in the current repository
return info return info
def get_info(self, path, include_diff=False): def get_info(self, path, include_diff=False, diff_from_remote=False):
""" """
Get repository information. Get repository information.
:param path: Path to repository :param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available) :param include_diff: Whether to include the diff command's output (if available)
:param diff_from_remote: Whether to store the remote diff/commit based on the remote commit (not local commit)
:return: RepoInfo instance :return: RepoInfo instance
""" """
info = self._get_info(path, include_diff) info = self._get_info(path, include_diff, diff_from_remote=diff_from_remote)
return self._post_process_info(info) return self._post_process_info(info)
def _is_repo_type(self, script_path): def _is_repo_type(self, script_path):
@ -200,6 +234,9 @@ class GitDetector(Detector):
modified=["git", "ls-files", "-m"], modified=["git", "ls-files", "-m"],
branch_fallback=["git", "rev-parse", "--abbrev-ref", "HEAD"], branch_fallback=["git", "rev-parse", "--abbrev-ref", "HEAD"],
diff_fallback=["git", "diff"], diff_fallback=["git", "diff"],
diff_remote=["git", "diff", "--submodule=diff", ],
commit_remote=["git", "rev-parse", ],
diff_fallback_remote=["git", "diff", ],
) )
def _post_process_info(self, info): def _post_process_info(self, info):
@ -234,7 +271,7 @@ class EnvDetector(Detector):
except Exception: except Exception:
return Path.cwd() return Path.cwd()
def _get_info(self, _, include_diff=False): def _get_info(self, _, include_diff=False, diff_from_remote=None):
repository_url = VCS_REPOSITORY_URL.get() repository_url = VCS_REPOSITORY_URL.get()
if not repository_url: if not repository_url:

View File

@ -589,7 +589,8 @@ class ScriptInfo(object):
return '' return ''
@classmethod @classmethod
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None): def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None,
uncommitted_from_remote=False):
jupyter_filepath = cls._get_jupyter_notebook_filename() jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath: if jupyter_filepath:
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()] scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
@ -623,7 +624,8 @@ class ScriptInfo(object):
else: else:
try: try:
for i, d in enumerate(scripts_dir): for i, d in enumerate(scripts_dir):
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted) repo_info = plugin.get_info(
str(d), include_diff=check_uncommitted, diff_from_remote=uncommitted_from_remote)
if not repo_info.is_empty(): if not repo_info.is_empty():
script_dir = d script_dir = d
script_path = scripts_path[i] script_path = scripts_path[i]
@ -697,13 +699,14 @@ class ScriptInfo(object):
script_requirements) script_requirements)
@classmethod @classmethod
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None): def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None,
uncommitted_from_remote=False):
try: try:
if not filepaths: if not filepaths:
filepaths = [sys.argv[0], ] filepaths = [sys.argv[0], ]
return cls._get_script_info( return cls._get_script_info(
filepaths=filepaths, check_uncommitted=check_uncommitted, filepaths=filepaths, check_uncommitted=check_uncommitted,
create_requirements=create_requirements, log=log) create_requirements=create_requirements, log=log, uncommitted_from_remote=uncommitted_from_remote)
except Exception as ex: except Exception as ex:
if log: if log:
log.warning("Failed auto-detecting task repository: {}".format(ex)) log.warning("Failed auto-detecting task repository: {}".format(ex))

View File

@ -65,6 +65,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_force_requirements = {} _force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False) _store_diff = config.get('development.store_uncommitted_code_diff', False)
_store_remote_diff = config.get('development.store_code_diff_from_remote', False)
_offline_filename = 'task.json' _offline_filename = 'task.json'
class TaskTypes(Enum): class TaskTypes(Enum):
@ -276,7 +277,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
result, script_requirements = ScriptInfo.get( result, script_requirements = ScriptInfo.get(
filepaths=[self._calling_filename, sys.argv[0], ] filepaths=[self._calling_filename, sys.argv[0], ]
if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ], if ScriptInfo.is_running_from_module() else [sys.argv[0], self._calling_filename, ],
log=self.log, create_requirements=False, check_uncommitted=self._store_diff log=self.log, create_requirements=False,
check_uncommitted=self._store_diff, uncommitted_from_remote=self._store_remote_diff
) )
for msg in result.warning_messages: for msg in result.warning_messages:
self.get_logger().report_text(msg) self.get_logger().report_text(msg)