Add agent.force_git_ssh_protocol option to force all git links to ssh:// (issue #16)

Add git user/pass credentials for pip git packages (git+http and  git+ssh) (issue #22)
This commit is contained in:
allegroai 2020-06-18 01:55:14 +03:00
parent 257dd95401
commit 1f53a06299
11 changed files with 114 additions and 22 deletions

View File

@ -13,11 +13,13 @@ api {
}
agent {
# Set GIT user/pass credentials
# leave blank for GIT SSH credentials
# Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
git_user=""
git_pass=""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# unique name of this worker, if None, created based on hostname:process_id
# Overridden with os environment: TRAINS_WORKER_NAME
@ -109,6 +111,11 @@ agent {
# optional arguments to pass to docker image
# arguments: ["--ipc=host"]
}
# CUDA versions used for Conda setup & solving PyTorch wheel packages
# it Should be detected automatically. Override with os environment CUDA_VERSION / CUDNN_VERSION
# cuda_version: 10.1
# cudnn_version: 7.6
}
sdk {

View File

@ -9,10 +9,14 @@
# worker_name: "trains-agent-machine1"
worker_name: ""
# Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials.
# Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https)
# leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol)
# git_user: ""
# git_pass: ""
# Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank)
force_git_ssh_protocol: false
# Set the python version to use when creating the virtual environment and launching the experiment
# Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6"
# The default is the python executing the trains_agent

View File

@ -555,3 +555,17 @@ class ExecutionInfo(NonStrictAttrs):
execution.working_dir = working_dir or ""
return execution
class safe_furl(furl.furl):
@property
def port(self):
return self._port
@port.setter
def port(self, port):
"""
Any port value is valid
"""
self._port = port

View File

@ -111,10 +111,12 @@ class PackageManager(object):
def out_of_scope_install_package(cls, package_name, *args):
if PackageManager._selected_manager is not None:
try:
return PackageManager._selected_manager._install(package_name, *args)
result = PackageManager._selected_manager._install(package_name, *args)
if result not in (0, None, True):
return False
except Exception:
pass
return
return False
return True
@classmethod
def out_of_scope_freeze(cls):

View File

@ -378,7 +378,7 @@ class CondaAPI(PackageManager):
print(e)
raise e
self.requirements_manager.post_install()
self.requirements_manager.post_install(self.session)
return True
def _parse_conda_result_bad_packges(self, result_dict):

View File

@ -3,6 +3,7 @@ from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
from ..base import safe_furl as furl
class ExternalRequirements(SimpleSubstitution):
@ -22,7 +23,7 @@ class ExternalRequirements(SimpleSubstitution):
return False
return True
def post_install(self):
def post_install(self, session):
post_install_req = self.post_install_req
self.post_install_req = []
for req in post_install_req:
@ -30,7 +31,30 @@ class ExternalRequirements(SimpleSubstitution):
freeze_base = PackageManager.out_of_scope_freeze() or ''
except:
freeze_base = ''
PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--no-deps")
req_line = req.tostr(markers=False)
if req.req.vcs and req_line.startswith('git+'):
try:
url_no_frag = furl(req_line)
url_no_frag.set(fragment=None)
# reverse replace
fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1]
vcs_url = req_line[4:]
# reverse replace
vcs_url = vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1]
from ..repo import Git
vcs = Git(session=session, url=vcs_url, location=None, revision=None)
vcs._set_ssh_url()
new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment)
if new_req_line != req_line:
url_pass = furl(new_req_line).password
print('Replacing original pip vcs \'{}\' with \'{}\''.format(
req_line, new_req_line.replace(url_pass, '****', 1) if url_pass else new_req_line))
req_line = new_req_line
except Exception:
print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line))
PackageManager.out_of_scope_install_package(req_line, "--no-deps")
try:
freeze_post = PackageManager.out_of_scope_freeze() or ''
package_name = list(set(freeze_post['pip']) - set(freeze_base['pip']))
@ -38,7 +62,8 @@ class ExternalRequirements(SimpleSubstitution):
self.post_install_req_lookup[package_name[0]] = req.req.line
except:
pass
PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--ignore-installed")
if not PackageManager.out_of_scope_install_package(req_line, "--ignore-installed"):
raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line))
def replace(self, req):
"""

View File

@ -16,7 +16,7 @@ class HorovodRequirement(SimpleSubstitution):
# match both horovod
return req.name and self.name == req.name.lower()
def post_install(self):
def post_install(self, session):
if self.post_install_req:
PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False))
self.post_install_req = None

View File

@ -37,7 +37,7 @@ class VirtualenvPip(SystemPip, PackageManager):
if isinstance(requirements, dict) and requirements.get("pip"):
requirements["pip"] = self.requirements_manager.replace(requirements["pip"])
super(VirtualenvPip, self).load_requirements(requirements)
self.requirements_manager.post_install()
self.requirements_manager.post_install(self.session)
def create_flags(self):
"""

View File

@ -270,7 +270,7 @@ class PytorchRequirement(SimpleSubstitution):
def get_url_for_platform(self, req):
# check if package is already installed with system packages
try:
if self.config.get("agent.package_manager.system_site_packages"):
if self.config.get("agent.package_manager.system_site_packages", None):
from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name]))
# notice the comparision order, the first part will make sure we have a valid installed package
@ -295,7 +295,7 @@ class PytorchRequirement(SimpleSubstitution):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version)
url = self._get_link_from_torch_page(req, torch_url)
if not url and self.config.get("agent.package_manager.torch_nightly"):
if not url and self.config.get("agent.package_manager.torch_nightly", None):
torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True)
url = self._get_link_from_torch_page(req, torch_url)
# try one more time, with a lower cuda version (never fallback to CPU):

View File

@ -54,7 +54,17 @@ class MarkerRequirement(object):
if self.specifier:
parts.append(self.format_specs())
elif self.vcs:
# leave the line as is, let pip handle it
if self.line:
parts = [self.line]
else:
# let's build the line manually
parts = [
self.uri,
'@{}'.format(self.revision) if self.revision else '',
'#subdirectory={}'.format(self.subdirectory) if self.subdirectory else ''
]
else:
parts = [self.uri]
@ -316,7 +326,7 @@ class RequirementSubstitution(object):
"""
pass
def post_install(self):
def post_install(self, session):
pass
@classmethod
@ -472,12 +482,13 @@ class RequirementsManager(object):
result = map(self.translator.translate, result)
return join_lines(result)
def post_install(self):
def post_install(self, session):
for h in self.handlers:
try:
h.post_install()
h.post_install(session)
except Exception as ex:
print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
raise
def replace_back(self, requirements):
for h in self.handlers:

View File

@ -97,7 +97,7 @@ class VCS(object):
:param session: program session
:param url: repository url
:param location: (desired) clone location
:param: desired clone revision
:param revision: desired clone revision
"""
self.session = session
self.log = self.session.get_logger(
@ -208,7 +208,7 @@ class VCS(object):
)
@classmethod
def resolve_ssh_url(cls, url):
def replace_ssh_url(cls, url):
# type: (Text) -> Text
"""
Replace SSH URL with HTTPS URL when applicable
@ -242,11 +242,37 @@ class VCS(object):
).url
return url
@classmethod
def replace_http_url(cls, url):
# type: (Text) -> Text
"""
Replace HTTPS URL with SSH URL when applicable
"""
parsed_url = furl(url)
if parsed_url.scheme == "https":
parsed_url.scheme = "ssh"
parsed_url.username = "git"
parsed_url.password = None
# make sure there is no port in the final url (safe_furl support)
parsed_url.port = None
url = parsed_url.url
return url
def _set_ssh_url(self):
"""
Replace instance URL with SSH substitution result and report to log.
According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work.
"""
if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url:
parsed_url = furl(self.url)
if parsed_url.scheme == "https":
new_url = self.replace_http_url(self.url)
if new_url != self.url:
print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format(
self.url, new_url))
self.url = new_url
return
if not self.session.config.agent.translate_ssh:
return
@ -255,7 +281,7 @@ class VCS(object):
(ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and
(ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None))
):
new_url = self.resolve_ssh_url(self.url)
new_url = self.replace_ssh_url(self.url)
if new_url != self.url:
print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format(
self.url, new_url))
@ -396,7 +422,10 @@ class VCS(object):
Add username and password to URL if missing from URL and present in config.
Does not modify ssh URLs.
"""
parsed_url = furl(url)
try:
parsed_url = furl(url)
except ValueError:
return url
if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"):
return parsed_url.url
config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)