From 1f53a06299f90a9ff1d4bc6feebda1decaf5a4ba Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 18 Jun 2020 01:55:14 +0300 Subject: [PATCH] Add agent.force_git_ssh_protocol option to force all git links to ssh:// (issue #16) Add git user/pass credentials for pip git packages (git+http and git+ssh) (issue #22) --- docs/trains.conf | 11 +++++- .../backend_api/config/default/agent.conf | 6 ++- trains_agent/helper/base.py | 14 +++++++ trains_agent/helper/package/base.py | 8 ++-- trains_agent/helper/package/conda_api.py | 2 +- trains_agent/helper/package/external_req.py | 31 ++++++++++++++-- trains_agent/helper/package/horovod_req.py | 2 +- trains_agent/helper/package/pip_api/venv.py | 2 +- trains_agent/helper/package/pytorch.py | 4 +- trains_agent/helper/package/requirements.py | 19 ++++++++-- trains_agent/helper/repo.py | 37 +++++++++++++++++-- 11 files changed, 114 insertions(+), 22 deletions(-) diff --git a/docs/trains.conf b/docs/trains.conf index 56f68c2..80ae3fa 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -13,11 +13,13 @@ api { } agent { - # Set GIT user/pass credentials - # leave blank for GIT SSH credentials + # Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https) + # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol) git_user="" git_pass="" + # Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank) + force_git_ssh_protocol: false # unique name of this worker, if None, created based on hostname:process_id # Overridden with os environment: TRAINS_WORKER_NAME @@ -109,6 +111,11 @@ agent { # optional arguments to pass to docker image # arguments: ["--ipc=host"] } + + # CUDA versions used for Conda setup & solving PyTorch wheel packages + # They should be detected automatically. 
Override with os environment CUDA_VERSION / CUDNN_VERSION + # cuda_version: 10.1 + # cudnn_version: 7.6 } sdk { diff --git a/trains_agent/backend_api/config/default/agent.conf b/trains_agent/backend_api/config/default/agent.conf index ebc3aa0..db16d51 100644 --- a/trains_agent/backend_api/config/default/agent.conf +++ b/trains_agent/backend_api/config/default/agent.conf @@ -9,10 +9,14 @@ # worker_name: "trains-agent-machine1" worker_name: "" - # Set GIT user/pass credentials for cloning code, leave blank for GIT SSH credentials. + # Set GIT user/pass credentials (if user/pass are set, GIT protocol will be set to https) + # leave blank for GIT SSH credentials (set force_git_ssh_protocol=true to force SSH protocol) # git_user: "" # git_pass: "" + # Force GIT protocol to use SSH regardless of the git url (Assumes GIT user/pass are blank) + force_git_ssh_protocol: false + # Set the python version to use when creating the virtual environment and launching the experiment # Example values: "/usr/bin/python3" or "/usr/local/bin/python3.6" # The default is the python executing the trains_agent diff --git a/trains_agent/helper/base.py b/trains_agent/helper/base.py index 4aaedcc..a99e389 100644 --- a/trains_agent/helper/base.py +++ b/trains_agent/helper/base.py @@ -555,3 +555,17 @@ class ExecutionInfo(NonStrictAttrs): execution.working_dir = working_dir or "" return execution + + +class safe_furl(furl.furl): + + @property + def port(self): + return self._port + + @port.setter + def port(self, port): + """ + Any port value is valid + """ + self._port = port diff --git a/trains_agent/helper/package/base.py b/trains_agent/helper/package/base.py index af6d383..1fce1a6 100644 --- a/trains_agent/helper/package/base.py +++ b/trains_agent/helper/package/base.py @@ -111,10 +111,12 @@ class PackageManager(object): def out_of_scope_install_package(cls, package_name, *args): if PackageManager._selected_manager is not None: try: - return 
PackageManager._selected_manager._install(package_name, *args) + result = PackageManager._selected_manager._install(package_name, *args) + if result not in (0, None, True): + return False except Exception: - pass - return + return False + return True @classmethod def out_of_scope_freeze(cls): diff --git a/trains_agent/helper/package/conda_api.py b/trains_agent/helper/package/conda_api.py index ddda0c3..61dbecc 100644 --- a/trains_agent/helper/package/conda_api.py +++ b/trains_agent/helper/package/conda_api.py @@ -378,7 +378,7 @@ class CondaAPI(PackageManager): print(e) raise e - self.requirements_manager.post_install() + self.requirements_manager.post_install(self.session) return True def _parse_conda_result_bad_packges(self, result_dict): diff --git a/trains_agent/helper/package/external_req.py b/trains_agent/helper/package/external_req.py index e51165c..52f2a86 100644 --- a/trains_agent/helper/package/external_req.py +++ b/trains_agent/helper/package/external_req.py @@ -3,6 +3,7 @@ from typing import Text from .base import PackageManager from .requirements import SimpleSubstitution +from ..base import safe_furl as furl class ExternalRequirements(SimpleSubstitution): @@ -22,7 +23,7 @@ class ExternalRequirements(SimpleSubstitution): return False return True - def post_install(self): + def post_install(self, session): post_install_req = self.post_install_req self.post_install_req = [] for req in post_install_req: @@ -30,7 +31,30 @@ class ExternalRequirements(SimpleSubstitution): freeze_base = PackageManager.out_of_scope_freeze() or '' except: freeze_base = '' - PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--no-deps") + + req_line = req.tostr(markers=False) + if req.req.vcs and req_line.startswith('git+'): + try: + url_no_frag = furl(req_line) + url_no_frag.set(fragment=None) + # reverse replace + fragment = req_line[::-1].replace(url_no_frag.url[::-1], '', 1)[::-1] + vcs_url = req_line[4:] + # reverse replace + vcs_url = 
vcs_url[::-1].replace(fragment[::-1], '', 1)[::-1] + from ..repo import Git + vcs = Git(session=session, url=vcs_url, location=None, revision=None) + vcs._set_ssh_url() + new_req_line = 'git+{}{}'.format(vcs.url_with_auth, fragment) + if new_req_line != req_line: + url_pass = furl(new_req_line).password + print('Replacing original pip vcs \'{}\' with \'{}\''.format( + req_line, new_req_line.replace(url_pass, '****', 1) if url_pass else new_req_line)) + req_line = new_req_line + except Exception: + print('WARNING: Failed parsing pip git install, using original line {}'.format(req_line)) + + PackageManager.out_of_scope_install_package(req_line, "--no-deps") try: freeze_post = PackageManager.out_of_scope_freeze() or '' package_name = list(set(freeze_post['pip']) - set(freeze_base['pip'])) @@ -38,7 +62,8 @@ class ExternalRequirements(SimpleSubstitution): self.post_install_req_lookup[package_name[0]] = req.req.line except: pass - PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--ignore-installed") + if not PackageManager.out_of_scope_install_package(req_line, "--ignore-installed"): + raise ValueError("Failed installing GIT/HTTPs package \'{}\'".format(req_line)) def replace(self, req): """ diff --git a/trains_agent/helper/package/horovod_req.py b/trains_agent/helper/package/horovod_req.py index ff1be9d..55e61c6 100644 --- a/trains_agent/helper/package/horovod_req.py +++ b/trains_agent/helper/package/horovod_req.py @@ -16,7 +16,7 @@ class HorovodRequirement(SimpleSubstitution): # match both horovod return req.name and self.name == req.name.lower() - def post_install(self): + def post_install(self, session): if self.post_install_req: PackageManager.out_of_scope_install_package(self.post_install_req.tostr(markers=False)) self.post_install_req = None diff --git a/trains_agent/helper/package/pip_api/venv.py b/trains_agent/helper/package/pip_api/venv.py index f9c14c2..dc97b01 100644 --- a/trains_agent/helper/package/pip_api/venv.py +++ 
b/trains_agent/helper/package/pip_api/venv.py @@ -37,7 +37,7 @@ class VirtualenvPip(SystemPip, PackageManager): if isinstance(requirements, dict) and requirements.get("pip"): requirements["pip"] = self.requirements_manager.replace(requirements["pip"]) super(VirtualenvPip, self).load_requirements(requirements) - self.requirements_manager.post_install() + self.requirements_manager.post_install(self.session) def create_flags(self): """ diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 7ed30b5..f58997f 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -270,7 +270,7 @@ class PytorchRequirement(SimpleSubstitution): def get_url_for_platform(self, req): # check if package is already installed with system packages try: - if self.config.get("agent.package_manager.system_site_packages"): + if self.config.get("agent.package_manager.system_site_packages", None): from pip._internal.commands.show import search_packages_info installed_torch = list(search_packages_info([req.name])) # notice the comparision order, the first part will make sure we have a valid installed package @@ -295,7 +295,7 @@ class PytorchRequirement(SimpleSubstitution): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version) url = self._get_link_from_torch_page(req, torch_url) - if not url and self.config.get("agent.package_manager.torch_nightly"): + if not url and self.config.get("agent.package_manager.torch_nightly", None): torch_url, torch_url_key = SimplePytorchRequirement.get_torch_page(self.cuda_version, nightly=True) url = self._get_link_from_torch_page(req, torch_url) # try one more time, with a lower cuda version (never fallback to CPU): diff --git a/trains_agent/helper/package/requirements.py b/trains_agent/helper/package/requirements.py index 9728bbd..ca47b6a 100644 --- a/trains_agent/helper/package/requirements.py +++ b/trains_agent/helper/package/requirements.py @@ -54,7 
+54,17 @@ class MarkerRequirement(object): if self.specifier: parts.append(self.format_specs()) - + elif self.vcs: + # leave the line as is, let pip handle it + if self.line: + parts = [self.line] + else: + # let's build the line manually + parts = [ + self.uri, + '@{}'.format(self.revision) if self.revision else '', + '#subdirectory={}'.format(self.subdirectory) if self.subdirectory else '' + ] else: parts = [self.uri] @@ -316,7 +326,7 @@ class RequirementSubstitution(object): """ pass - def post_install(self): + def post_install(self, session): pass @classmethod @@ -472,12 +482,13 @@ class RequirementsManager(object): result = map(self.translator.translate, result) return join_lines(result) - def post_install(self): + def post_install(self, session): for h in self.handlers: try: - h.post_install() + h.post_install(session) except Exception as ex: print('RequirementsManager handler {} raised exception: {}'.format(h, ex)) + raise def replace_back(self, requirements): for h in self.handlers: diff --git a/trains_agent/helper/repo.py b/trains_agent/helper/repo.py index 2d5c812..4343704 100644 --- a/trains_agent/helper/repo.py +++ b/trains_agent/helper/repo.py @@ -97,7 +97,7 @@ class VCS(object): :param session: program session :param url: repository url :param location: (desired) clone location - :param: desired clone revision + :param revision: desired clone revision """ self.session = session self.log = self.session.get_logger( @@ -208,7 +208,7 @@ class VCS(object): ) @classmethod - def resolve_ssh_url(cls, url): + def replace_ssh_url(cls, url): # type: (Text) -> Text """ Replace SSH URL with HTTPS URL when applicable @@ -242,11 +242,37 @@ class VCS(object): ).url return url + @classmethod + def replace_http_url(cls, url): + # type: (Text) -> Text + """ + Replace HTTPS URL with SSH URL when applicable + """ + parsed_url = furl(url) + if parsed_url.scheme == "https": + parsed_url.scheme = "ssh" + parsed_url.username = "git" + parsed_url.password = None + # make sure 
there is no port in the final url (safe_furl support) + parsed_url.port = None + url = parsed_url.url + return url + def _set_ssh_url(self): """ Replace instance URL with SSH substitution result and report to log. According to ``man ssh-add``, ``SSH_AUTH_SOCK`` must be set in order for ``ssh-add`` to work. """ + if self.session.config.get('agent.force_git_ssh_protocol', None) and self.url: + parsed_url = furl(self.url) + if parsed_url.scheme == "https": + new_url = self.replace_http_url(self.url) + if new_url != self.url: + print("Using SSH credentials - replacing https url '{}' with ssh url '{}'".format( + self.url, new_url)) + self.url = new_url + return + if not self.session.config.agent.translate_ssh: return @@ -255,7 +281,7 @@ class VCS(object): (ENV_AGENT_GIT_USER.get() or self.session.config.get('agent.git_user', None)) and (ENV_AGENT_GIT_PASS.get() or self.session.config.get('agent.git_pass', None)) ): - new_url = self.resolve_ssh_url(self.url) + new_url = self.replace_ssh_url(self.url) if new_url != self.url: print("Using user/pass credentials - replacing ssh url '{}' with https url '{}'".format( self.url, new_url)) @@ -396,7 +422,10 @@ class VCS(object): Add username and password to URL if missing from URL and present in config. Does not modify ssh URLs. """ - parsed_url = furl(url) + try: + parsed_url = furl(url) + except ValueError: + return url if parsed_url.scheme in ["", "ssh"] or parsed_url.scheme.startswith("git"): return parsed_url.url config_user = ENV_AGENT_GIT_USER.get() or config.get("agent.{}_user".format(cls.executable_name), None)