Add agent.force_git_ssh_protocol option to force all git links to ssh:// (issue #16)

Add git user/pass credentials for pip git packages (git+http and  git+ssh) (issue #22)
This commit is contained in:
allegroai 2020-06-18 01:55:14 +03:00
parent 257dd95401
commit 1f53a06299
11 changed files with 114 additions and 22 deletions

View File

@ -13,11 +13,13 @@ api {
} }
agent { agent {
# Set GIT user/pass credentials # Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
# leave blank for GIT SSH credentials # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
git_user="" git_user=""
git_pass="" git_pass=""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# unique name of this worker, if None, created based on hostname:process_id # unique name of this worker, if None, created based on hostname:process_id
# Overridden with os environment: TRAINS_WORKER_NAME # Overridden with os environment: TRAINS_WORKER_NAME
@ -109,6 +111,11 @@ agent {
# optional arguments to pass to docker image # optional arguments to pass to docker image
# arguments: ["--ipc=host"] # arguments: ["--ipc=host"]
} }
# CUDA versions used for Conda setup & solving PyTorch wheel packages
# it Should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
# cuda_version: 10.1
# cudnn_version: 7.6
} }
sdk { sdk {

View File

@ -9,10 +9,14 @@
# worker_name: "trains-agent-machine1" # worker_name: "trains-agent-machine1"
worker_name: "" worker_name: ""
# Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials. # Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
# git_user: "" # git_user: ""
# git_pass: "" # git_pass: ""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# Set the python version to use when creating the virtual environment and launching the experiment # Set the python version to use when creating the virtual environment and launching the experiment
# Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6" # Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
# The default is the python executing the trains_agent # The default is the python executing the trains_agent

View File

@ -555,3 +555,17 @@ class ExecutionInfo(NonStrictAttrs):
execution.working_dir = working_dir or "" execution.working_dir = working_dir or ""
return execution return execution
class safe_furl(furl.furl):
@property
def port(self):
return self._port
@port.setter
def port(self, port):
"""
Any port value is valid
"""
self._port = port

View File

@ -111,10 +111,12 @@ class PackageManager(object):
def out_of_scope_install_package(cls, package_name, *args): def out_of_scope_install_package(cls, package_name, *args):
if PackageManager._selected_manager is not None: if PackageManager._selected_manager is not None:
try: try:
return PackageManager._selected_manager._install(package_name, *args) result = PackageManager._selected_manager._install(package_name, *args)
if result not in (0, None, True):
return False
except Exception: except Exception:
pass return False
return return True
@classmethod @classmethod
def out_of_scope_freeze(cls): def out_of_scope_freeze(cls):

View File

@ -378,7 +378,7 @@ class CondaAPI(PackageManager):
print(e) print(e)
raise e raise e
self.requirements_manager.post_install() self.requirements_manager.post_install(self.session)
return True return True
def _parse_conda_result_bad_packges(self, result_dict): def _parse_conda_result_bad_packges(self, result_dict):

View File

@ -3,6 +3,7 @@ from typing import Text
from .base import PackageManager from .base import PackageManager
from .requirements import SimpleSubstitution from .requirements import SimpleSubstitution
from ..base import safe_furl as furl
class ExternalRequirements(SimpleSubstitution): class ExternalRequirements(SimpleSubstitution):
@ -22,7 +23,7 @@ class ExternalRequirements(SimpleSubstitution):
return False return False
return True return True
def post_install(self): def post_install(self, session):
post_install_req = self.post_install_req post_install_req = self.post_install_req
self.post_install_req = [] self.post_install_req = []
for req in post_install_req: for req in post_install_req:
@ -30,7 +31,30 @@ class ExternalRequirements(SimpleSubstitution):
freeze_base = PackageManager.out_of_scope_freeze() or '' freeze_base = PackageManager.out_of_scope_freeze() or ''
except: except:
freeze_base = '' freeze_base = ''
PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--no-deps")
req_line = req.tostr(markers=False)
if req.req.vcs and req_line.startswith('git+'):
try:
url_no_frag = furl(req_line)
url_no_frag.set(fragment=None)
# reverse replace
fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1]
vcs_url = req_line[4:]
# reverse replace
vcs_url = vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1]
from ..repo import Git
vcs = Git(session=session, url=vcs_url, location=None, revision=None)
vcs._set_ssh_url()
new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
if new_req_line != req_line:
url_pass = furl(new_req_line).password
print('Replacing original pip vcs \'{}\' with \'{}\''.format(
req_line, new_req_line.replace(url_pass, '****', 1) if url_pass else new_req_line))
req_line = new_req_line
except Exception:
print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
try: try:
freeze_post = PackageManager.out_of_scope_freeze() or '' freeze_post = PackageManager.out_of_scope_freeze() or ''
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip'])) package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
@ -38,7 +62,8 @@ class ExternalRequirements(SimpleSubstitution):
self.post_install_req_lookup[package_name[0]] = req.req.line self.post_install_req_lookup[package_name[0]] = req.req.line
except: except:
pass pass
PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--ignore-installed") if not PackageManager.out_of_scope_install_package(req_line, "--ignore-installed"):
raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))
def replace(self, req): def replace(self, req):
""" """

View File

@ -16,7 +16,7 @@ class HorovodRequirement(SimpleSubstitution):
# match both horovod # match both horovod
return req.name and self.name == req.name.lower() return req.name and self.name == req.name.lower()
def post_install(self): def post_install(self, session):
if self.post_install_req: if self.post_install_req:
PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False)) PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
self.post_install_req = None self.post_install_req = None

View File

@ -37,7 +37,7 @@ class VirtualenvPip(SystemPip, PackageManager):
if isinstance(requirements, dict) and requirements.get("pip"): if isinstance(requirements, dict) and requirements.get("pip"):
requirements["pip"] = self.requirements_manager.replace(requirements["pip"]) requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
super(VirtualenvPip, self).load_requirements(requirements) super(VirtualenvPip, self).load_requirements(requirements)
self.requirements_manager.post_install() self.requirements_manager.post_install(self.session)
def create_flags(self): def create_flags(self):
""" """

View File

@ -270,7 +270,7 @@ class PytorchRequirement(SimpleSubstitution):
def get_url_for_platform(self, req): def get_url_for_platform(self, req):
# check if package is already installed with system packages # check if package is already installed with system packages
try: try:
if self.config.get("agent.package_manager.system_site_packages"): if self.config.get("agent.package_manager.system_site_packages", None):
from pip._internal.commands.show import search_packages_info from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name])) installed_torch = list(search_packages_info([req.name]))
# notice the comparision order, the first part will make sure we have a valid installed package # notice the comparision order, the first part will make sure we have a valid installed package
@ -295,7 +295,7 @@ class PytorchRequirement(SimpleSubstitution):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url) url = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly"): if not url and self.config.get("agent.package_manager.torch_nightly", None):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True) torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url = self._get_link_from_torch_page(req, torch_url) url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU): # try one more time, with a lower cuda version (never fallback to CPU):

View File

@ -54,7 +54,17 @@ class MarkerRequirement(object):
if self.specifier: if self.specifier:
parts.append(self.format_specs()) parts.append(self.format_specs())
elif self.vcs:
# leave the line as is, let pip handle it
if self.line:
parts = [self.line]
else:
# let's build the line manually
parts = [
self.uri,
'@{}'.format(self.revision) if self.revision else '',
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
]
else: else:
parts = [self.uri] parts = [self.uri]
@ -316,7 +326,7 @@ class RequirementSubstitution(object):
""" """
pass pass
def post_install(self): def post_install(self, session):
pass pass
@classmethod @classmethod
@ -472,12 +482,13 @@ class RequirementsManager(object):
result = map(self.translator.translate, result) result = map(self.translator.translate, result)
return join_lines(result) return join_lines(result)
def post_install(self): def post_install(self, session):
for h in self.handlers: for h in self.handlers:
try: try:
h.post_install() h.post_install(session)
except Exception as ex: except Exception as ex:
print('RequirementsManager handler {} raised exception: {}'.format(h, ex)) print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
raise
def replace_back(self, requirements): def replace_back(self, requirements):
for h in self.handlers: for h in self.handlers:

View File

@ -97,7 +97,7 @@ class VCS(object):
:param session: program session :param session: program session
:param url: repository url :param url: repository url
:param location: (desired) clone location :param location: (desired) clone location
:param: desired clone revision :param revision: desired clone revision
""" """
self.session = session self.session = session
self.log = self.session.get_logger( self.log = self.session.get_logger(
@ -208,7 +208,7 @@ class VCS(object):
) )
@classmethod @classmethod
def resolve_ssh_url(cls, url): def replace_ssh_url(cls, url):
# type: (Text) -> Text # type: (Text) -> Text
""" """
Replace SSH URL with HTTPS URL when applicable Replace SSH URL with HTTPS URL when applicable
@ -242,11 +242,37 @@ class VCS(object):
).url ).url
return url return url
@classmethod
def replace_http_url(cls, url):
# type: (Text) -> Text
"""
Replace HTTPS URL with SSH URL when applicable
"""
parsed_url = furl(url)
if parsed_url.scheme == "https":
parsed_url.scheme = "ssh"
parsed_url.username = "git"
parsed_url.password = None
# make sure there is no port in the final url (safe_furl support)
parsed_url.port = None
url = parsed_url.url
return url
def _set_ssh_url(self): def _set_ssh_url(self):
""" """
Replace instance URL with SSH substitution result and report to log. Replace instance URL with SSH substitution result and report to log.
According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work. According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
""" """
if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url:
parsed_url = furl(self.url)
if parsed_url.scheme == "https":
new_url = self.replace_http_url(self.url)
if new_url != self.url:
print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
self.url, new_url))
self.url = new_url
return
if not self.session.config.agent.translate_ssh: if not self.session.config.agent.translate_ssh:
return return
@ -255,7 +281,7 @@ class VCS(object):
(ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)) (ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None))
): ):
new_url = self.resolve_ssh_url(self.url) new_url = self.replace_ssh_url(self.url)
if new_url != self.url: if new_url != self.url:
print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format( print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
self.url, new_url)) self.url, new_url))
@ -396,7 +422,10 @@ class VCS(object):
Add username and password to URL if missing from URL and present in config. Add username and password to URL if missing from URL and present in config.
Does not modify ssh URLs. Does not modify ssh URLs.
""" """
try:
parsed_url = furl(url) parsed_url = furl(url)
except ValueError:
return url
if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"): if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"):
return parsed_url.url return parsed_url.url
config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None) config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)