Fix requirements handling and poetry support

This commit is contained in:
allegroai 2020-01-16 11:10:38 +02:00
parent 6912846326
commit b7e568e299
6 changed files with 127 additions and 25 deletions

View File

@ -64,6 +64,7 @@ from trains_agent.helper.console import ensure_text
from trains_agent.helper.package.base import PackageManager from trains_agent.helper.package.base import PackageManager
from trains_agent.helper.package.conda_api import CondaAPI from trains_agent.helper.package.conda_api import CondaAPI
from trains_agent.helper.package.horovod_req import HorovodRequirement from trains_agent.helper.package.horovod_req import HorovodRequirement
from trains_agent.helper.package.external_req import ExternalRequirements
from trains_agent.helper.package.pip_api.system import SystemPip from trains_agent.helper.package.pip_api.system import SystemPip
from trains_agent.helper.package.pip_api.venv import VirtualenvPip from trains_agent.helper.package.pip_api.venv import VirtualenvPip
from trains_agent.helper.package.poetry_api import PoetryConfig, PoetryAPI from trains_agent.helper.package.poetry_api import PoetryConfig, PoetryAPI
@ -287,6 +288,7 @@ class Worker(ServiceCommandSection):
PytorchRequirement, PytorchRequirement,
CythonRequirement, CythonRequirement,
HorovodRequirement, HorovodRequirement,
ExternalRequirements,
) )
# poll queues every _polling_interval seconds # poll queues every _polling_interval seconds
@ -350,7 +352,6 @@ class Worker(ServiceCommandSection):
self.is_venv_update = self._session.config.agent.venv_update.enabled self.is_venv_update = self._session.config.agent.venv_update.enabled
self.poetry = PoetryConfig(self._session) self.poetry = PoetryConfig(self._session)
self.poetry.initialize()
self.docker_image_func = None self.docker_image_func = None
self._docker_image = None self._docker_image = None
self._docker_arguments = None self._docker_arguments = None
@ -934,7 +935,7 @@ class Worker(ServiceCommandSection):
cached_requirements=requirements, cached_requirements=requirements,
cwd=vcs.location if vcs and vcs.location else directory, cwd=vcs.location if vcs and vcs.location else directory,
) )
freeze = self.freeze_task_environment() freeze = self.freeze_task_environment(requirements_manager=requirements_manager)
script_dir = directory script_dir = directory
# Summary # Summary
@ -1108,7 +1109,8 @@ class Worker(ServiceCommandSection):
# do not update the task packages if we are using conda, # do not update the task packages if we are using conda,
# it will most likely make the task environment unreproducible # it will most likely make the task environment unreproducible
freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None) freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None,
requirements_manager=requirements_manager)
script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix() script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
# run code # run code
@ -1371,13 +1373,17 @@ class Worker(ServiceCommandSection):
status_message=self._task_status_change_message, status_message=self._task_status_change_message,
) )
def freeze_task_environment(self, task_id=None): def freeze_task_environment(self, task_id=None, requirements_manager=None):
try: try:
freeze = self.package_api.freeze() freeze = self.package_api.freeze()
except Exception as e: except Exception as e:
print("Could not freeze installed packages") print("Could not freeze installed packages")
self.log_traceback(e) self.log_traceback(e)
return None return None
if requirements_manager:
freeze = requirements_manager.replace_back(freeze)
if not task_id: if not task_id:
return freeze return freeze
@ -1402,6 +1408,9 @@ class Worker(ServiceCommandSection):
if not repo_info: if not repo_info:
return None return None
try: try:
if not self.poetry.enabled:
return None
self.poetry.initialize(cwd=repo_info.root)
api = self.poetry.get_api(repo_info.root) api = self.poetry.get_api(repo_info.root)
if api.enabled: if api.enabled:
api.install() api.install()
@ -1447,7 +1456,7 @@ class Worker(ServiceCommandSection):
except Exception as e: except Exception as e:
self.log_traceback(e) self.log_traceback(e)
cached_requirements_failed = True cached_requirements_failed = True
raise ValueError("Could not install task requirements!") raise ValueError("Could not install task requirements!\n{}".format(e))
else: else:
self.log("task requirements installation passed") self.log("task requirements installation passed")
return return

View File

@ -107,10 +107,19 @@ class PackageManager(object):
self._cwd = value self._cwd = value
@classmethod @classmethod
def out_of_scope_install_package(cls, package_name): def out_of_scope_install_package(cls, package_name, *args):
if PackageManager._selected_manager is not None: if PackageManager._selected_manager is not None:
try: try:
return PackageManager._selected_manager._install(package_name) return PackageManager._selected_manager._install(package_name, *args)
except Exception: except Exception:
pass pass
return return
@classmethod
def out_of_scope_freeze(cls):
    """
    Freeze the environment through the currently selected package manager.

    Best-effort helper for code that has no package manager instance at hand.

    :return: whatever the selected manager's ``freeze()`` returns, or an empty
        list when no manager is selected or the freeze raised an exception
    """
    manager = PackageManager._selected_manager
    # no manager registered yet - nothing to freeze
    if manager is None:
        return []
    try:
        return manager.freeze()
    except Exception:
        # a failing freeze is reported to the caller as "nothing frozen"
        return []

View File

@ -0,0 +1,60 @@
from collections import OrderedDict
from typing import Text
from .base import PackageManager
from .requirements import SimpleSubstitution
class ExternalRequirements(SimpleSubstitution):
    """
    Requirement handler for "external" requirement lines: editable installs,
    VCS links and lines that carry no package name.

    Matching requirements are removed from the regular installation
    (``replace`` returns an empty line) and installed later by ``post_install``.
    While installing, the handler records which frozen package line each
    requirement produced, so ``replace_back`` can swap the frozen line for the
    original requirement line.
    """

    name = "external_link"

    def __init__(self, *args, **kwargs):
        super(ExternalRequirements, self).__init__(*args, **kwargs)
        # requirements deferred by replace(), consumed by post_install()
        self.post_install_req = []
        # frozen package line -> original requirement line (insertion ordered)
        self.post_install_req_lookup = OrderedDict()

    def match(self, req):
        """
        Return True when ``req`` is an editable / VCS / unnamed requirement
        that has a non-empty, non-comment source line.
        """
        # match both editable or code or unparsed
        is_external = not req.name or (req.req and (req.req.editable or req.req.vcs))
        if not is_external:
            return False
        line = (req.req.line or '').strip() if req.req else ''
        return bool(line) and not line.startswith('#')

    @staticmethod
    def _freeze_pip():
        """
        Return the current frozen 'pip' package list, or None when the freeze
        is unavailable (no manager selected, freeze failed, no 'pip' section).
        """
        try:
            return list(PackageManager.out_of_scope_freeze()['pip'])
        except Exception:
            return None

    def _record_installed(self, req, before, after):
        """
        Remember which frozen package line appeared after installing ``req``.

        Bookkeeping is best-effort and must never block the installation, so
        any failure here is ignored.
        """
        # without both snapshots we cannot tell what the install added
        if before is None or after is None:
            return
        try:
            known = set(before)
            # keep freeze order so the selected package is deterministic
            added = [pkg for pkg in after if pkg not in known]
            if added and added[0] not in self.post_install_req_lookup:
                self.post_install_req_lookup[added[0]] = req.req.line
        except Exception:
            pass

    def post_install(self):
        """
        Install every requirement deferred by ``replace``.

        Each requirement is first installed with ``--no-deps`` between two
        freeze snapshots (to detect the package it adds), then installed again
        with ``--ignore-installed``.
        """
        pending, self.post_install_req = self.post_install_req, []
        for req in pending:
            before = self._freeze_pip()
            PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--no-deps")
            after = self._freeze_pip()
            self._record_installed(req, before, after)
            PackageManager.out_of_scope_install_package(req.tostr(markers=False), "--ignore-installed")

    def replace(self, req):
        """
        Defer ``req`` to the post-install hook.

        :return: an empty requirement line, so the package is skipped by the
            regular installation
        """
        # Store in post req install, and return nothing
        self.post_install_req.append(req)
        # mark skip package, we will install it in post install hook
        return Text('')

    def replace_back(self, list_of_requirements):
        """
        Swap frozen package lines recorded by ``post_install`` for their
        original requirement lines.

        :param list_of_requirements: {'pip': ['a==1.0', ]}; updated in place
        :return: the same dictionary; recorded packages found in the 'pip' list
            are removed and their original lines appended at the end
        """
        if 'pip' in list_of_requirements:
            frozen = list_of_requirements['pip']
            lookup = self.post_install_req_lookup
            kept = [r for r in frozen if r not in lookup]
            restored = [line for pkg, line in lookup.items() if pkg in frozen]
            list_of_requirements['pip'] = kept + restored
        return list_of_requirements

View File

@ -35,7 +35,7 @@ class SystemPip(PackageManager):
self._install(*(packages + self.install_flags())) self._install(*(packages + self.install_flags()))
def _install(self, *args): def _install(self, *args):
self.run_with_env(('install',) + args) self.run_with_env(('install',) + args, cwd=self.cwd)
def uninstall_packages(self, *packages): def uninstall_packages(self, *packages):
self.run_with_env(('uninstall', '-y') + packages) self.run_with_env(('uninstall', '-y') + packages)

View File

@ -1,6 +1,7 @@
from functools import wraps from functools import wraps
import attr import attr
import sys
from pathlib2 import Path from pathlib2 import Path
from trains_agent.helper.process import Argv, DEVNULL from trains_agent.helper.process import Argv, DEVNULL
from trains_agent.session import Session, POETRY from trains_agent.session import Session, POETRY
@ -35,10 +36,12 @@ def prop_guard(prop, log_prop=None):
class PoetryConfig: class PoetryConfig:
def __init__(self, session): def __init__(self, session, interpreter=None):
# type: (Session) -> () # type: (Session, str) -> ()
self.session = session self.session = session
self._log = session.get_logger(__name__) self._log = session.get_logger(__name__)
self._python = interpreter or sys.executable
self._initialized = False
@property @property
def log(self): def log(self):
@ -53,7 +56,7 @@ class PoetryConfig:
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
func = kwargs.pop("func", Argv.get_output) func = kwargs.pop("func", Argv.get_output)
kwargs.setdefault("stdin", DEVNULL) kwargs.setdefault("stdin", DEVNULL)
argv = Argv("poetry", "-n", *args) argv = Argv(self._python, "-m", "poetry", *args)
self.log.debug("running: %s", argv) self.log.debug("running: %s", argv)
return func(argv, **kwargs) return func(argv, **kwargs)
@ -61,10 +64,12 @@ class PoetryConfig:
return self.run("config", *args, **kwargs) return self.run("config", *args, **kwargs)
@_guard_enabled @_guard_enabled
def initialize(self): def initialize(self, cwd=None):
self._config("settings.virtualenvs.in-project", "true") if not self._initialized:
# self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX) self._initialized = True
# self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS) self._config("--local", "virtualenvs.in-project", "true", cwd=cwd)
# self._config("repositories.{}".format(self.REPO_NAME), PYTHON_INDEX)
# self._config("http-basic.{}".format(self.REPO_NAME), *PYTHON_INDEX_CREDENTIALS)
def get_api(self, path): def get_api(self, path):
# type: (Path) -> PoetryAPI # type: (Path) -> PoetryAPI
@ -81,7 +86,7 @@ class PoetryAPI(object):
def install(self): def install(self):
# type: () -> bool # type: () -> bool
if self.enabled: if self.enabled:
self.config.run("install", cwd=str(self.path), func=Argv.check_call) self.config.run("install", "-n", cwd=str(self.path), func=Argv.check_call)
return True return True
return False return False
@ -92,7 +97,9 @@ class PoetryAPI(object):
) )
def freeze(self): def freeze(self):
return {"poetry": self.config.run("show", cwd=str(self.path)).splitlines()} lines = self.config.run("show", cwd=str(self.path)).splitlines()
lines = [[p for p in line.split(' ') if p] for line in lines]
return {"pip": [parts[0]+'=='+parts[1]+' # '+' '.join(parts[2:]) for parts in lines]}
def get_python_command(self, extra): def get_python_command(self, extra):
return Argv("poetry", "run", "python", *extra) return Argv("poetry", "run", "python", *extra)

View File

@ -8,7 +8,7 @@ from copy import deepcopy
from itertools import chain, starmap from itertools import chain, starmap
from operator import itemgetter from operator import itemgetter
from os import path from os import path
from typing import Text, List, Type, Optional, Tuple from typing import Text, List, Type, Optional, Tuple, Dict
from packaging import version as packaging_version from packaging import version as packaging_version
from pathlib2 import Path from pathlib2 import Path
@ -184,6 +184,13 @@ class SimpleSubstitution(RequirementSubstitution):
req.specs = [('==', version_number + self.suffix)] req.specs = [('==', version_number + self.suffix)]
return Text(req) return Text(req)
def replace_back(self, list_of_requirements):  # type: (Dict) -> Dict
    """
    Hook for post-processing a frozen requirements dictionary.

    The base implementation is a pass-through; subclasses may override it.

    :param list_of_requirements: {'pip': ['a==1.0', ]}
    :return: {'pip': ['a==1.0', ]} - the very same object that was passed in
    """
    return list_of_requirements
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class CudaSensitiveSubstitution(SimpleSubstitution): class CudaSensitiveSubstitution(SimpleSubstitution):
@ -235,15 +242,17 @@ class RequirementsManager(object):
return None return None
def replace(self, requirements): # type: (Text) -> Text def replace(self, requirements): # type: (Text) -> Text
def safe_parse(req_str):
try:
return next(parse(req_str))
except Exception as ex:
return Requirement(req_str)
parsed_requirements = tuple( parsed_requirements = tuple(
map( map(
MarkerRequirement, MarkerRequirement,
filter( [safe_parse(line) for line in (requirements.splitlines()
None, if isinstance(requirements, six.text_type) else requirements)]
parse(requirements)
if isinstance(requirements, six.text_type)
else (next(parse(line), None) for line in requirements)
)
) )
) )
if not parsed_requirements: if not parsed_requirements:
@ -280,6 +289,14 @@ class RequirementsManager(object):
except Exception as ex: except Exception as ex:
print('RequirementsManager handler {} raised exception: {}'.format(h, ex)) print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
def replace_back(self, requirements):
    """
    Pass frozen requirements through every handler's ``replace_back`` hook.

    Handlers are applied in ``self.handlers`` order, each receiving the
    previous handler's result. A handler that raises is reported and skipped,
    so one faulty handler cannot abort the whole chain.

    :param requirements: frozen requirements, e.g. {'pip': ['a==1.0', ]}
    :return: the requirements after all handlers were applied
    """
    for h in self.handlers:
        try:
            requirements = h.replace_back(requirements)
        except Exception as ex:
            # still best-effort, but report the failure instead of hiding it
            print('RequirementsManager handler {} raised exception: {}'.format(h, ex))
    return requirements
@staticmethod @staticmethod
def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text) def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text)
# we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version'] # we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']