Fix agent fails to check out code from main branch when branch/commit is not explicitly specified

This commit is contained in:
allegroai 2022-02-07 20:04:08 +02:00
parent bfed3ccf4d
commit 1f53c4fd1b
2 changed files with 42 additions and 32 deletions

View File

@ -2135,32 +2135,32 @@ class Worker(ServiceCommandSection):
cwd = vcs.location if vcs and vcs.location else directory cwd = vcs.location if vcs and vcs.location else directory
if is_cached and not standalone_mode: if not standalone_mode:
# reinstalling git / local packages if is_cached:
package_api = copy(self.package_api) # reinstalling git / local packages
OnlyExternalRequirements.cwd = package_api.cwd = cwd package_api = copy(self.package_api)
package_api.requirements_manager = self._get_requirements_manager( OnlyExternalRequirements.cwd = package_api.cwd = cwd
base_interpreter=package_api.requirements_manager.get_interpreter(), package_api.requirements_manager = self._get_requirements_manager(
requirement_substitutions=[OnlyExternalRequirements] base_interpreter=package_api.requirements_manager.get_interpreter(),
) requirement_substitutions=[OnlyExternalRequirements]
# make sure we run the handlers )
cached_requirements = \ # make sure we run the handlers
{k: package_api.requirements_manager.replace(requirements[k] or '') cached_requirements = \
for k in requirements} {k: package_api.requirements_manager.replace(requirements[k] or '')
if str(cached_requirements.get('pip', '')).strip() \ for k in requirements}
or str(cached_requirements.get('conda', '')).strip(): if str(cached_requirements.get('pip', '')).strip() \
package_api.load_requirements(cached_requirements) or str(cached_requirements.get('conda', '')).strip():
# make sure we call the correct freeze package_api.load_requirements(cached_requirements)
requirements_manager = package_api.requirements_manager # make sure we call the correct freeze
requirements_manager = package_api.requirements_manager
elif not is_cached and not standalone_mode: else:
self.install_requirements( self.install_requirements(
execution, execution,
repo_info, repo_info,
requirements_manager=requirements_manager, requirements_manager=requirements_manager,
cached_requirements=requirements, cached_requirements=requirements,
cwd=cwd, cwd=cwd,
) )
# do not update the task packages if we are using conda, # do not update the task packages if we are using conda,
# it will most likely make the task environment unreproducible # it will most likely make the task environment unreproducible

View File

@ -108,7 +108,7 @@ class VCS(object):
) )
self.url = url self.url = url
self.location = Text(location) self.location = Text(location)
self.revision = revision self._revision = revision
self.log = self.session.get_logger(__name__) self.log = self.session.get_logger(__name__)
@property @property
@ -390,7 +390,7 @@ class VCS(object):
""" """
Checkout repository at specified revision Checkout repository at specified revision
""" """
self.call("checkout", self.revision, *self.checkout_flags, cwd=self.location) self.call("checkout", self._revision, *self.checkout_flags, cwd=self.location)
@abc.abstractmethod @abc.abstractmethod
def pull(self): def pull(self):
@ -519,7 +519,7 @@ class VCS(object):
class Git(VCS): class Git(VCS):
executable_name = "git" executable_name = "git"
main_branch = "master" main_branch = ("master", "main")
clone_flags = ("--quiet", "--recursive") clone_flags = ("--quiet", "--recursive")
checkout_flags = ("--force",) checkout_flags = ("--force",)
COMMAND_ENV = { COMMAND_ENV = {
@ -531,7 +531,9 @@ class Git(VCS):
@staticmethod @staticmethod
def remote_branch_name(branch): def remote_branch_name(branch):
return "origin/{}".format(branch) return [
"origin/{}".format(b) for b in ([branch] if isinstance(branch, str) else branch)
]
def executable_not_found_error_help(self): def executable_not_found_error_help(self):
return 'Cannot find "{}" executable. {}'.format( return 'Cannot find "{}" executable. {}'.format(
@ -553,7 +555,15 @@ class Git(VCS):
""" """
Checkout repository at specified revision Checkout repository at specified revision
""" """
self.call("checkout", self.revision, *self.checkout_flags, cwd=self.location) revisions = [self._revision] if isinstance(self._revision, str) else self._revision
for i, revision in enumerate(revisions):
try:
self.call("checkout", revision, *self.checkout_flags, cwd=self.location)
break
except subprocess.CalledProcessError:
if i == len(revisions) - 1:
raise
try: try:
self.call("submodule", "update", "--recursive", cwd=self.location) self.call("submodule", "update", "--recursive", cwd=self.location)
except: # noqa except: # noqa
@ -593,7 +603,7 @@ class Hg(VCS):
"pull", "pull",
self.url_with_auth, self.url_with_auth,
cwd=self.location, cwd=self.location,
*(("-r", self.revision) if self.revision else ()) *(("-r", self._revision) if self._revision else ())
) )
info_commands = dict( info_commands = dict(