Fix venv caching to always reinstall git repositories and local repositories

This commit is contained in:
allegroai 2021-02-23 12:45:34 +02:00
parent 5da7184276
commit 0caf31719c
3 changed files with 64 additions and 7 deletions

View File

@ -12,7 +12,7 @@ import sys
import shutil
import traceback
from collections import defaultdict
from copy import deepcopy
from copy import deepcopy, copy
from datetime import datetime
from distutils.spawn import find_executable
from functools import partial, cmp_to_key
@ -73,7 +73,7 @@ from clearml_agent.helper.os.daemonize import daemonize_process
from clearml_agent.helper.package.base import PackageManager
from clearml_agent.helper.package.conda_api import CondaAPI
from clearml_agent.helper.package.post_req import PostRequirement
from clearml_agent.helper.package.external_req import ExternalRequirements
from clearml_agent.helper.package.external_req import ExternalRequirements, OnlyExternalRequirements
from clearml_agent.helper.package.pip_api.system import SystemPip
from clearml_agent.helper.package.pip_api.venv import VirtualenvPip
from clearml_agent.helper.package.poetry_api import PoetryConfig, PoetryAPI
@ -440,11 +440,12 @@ class Worker(ServiceCommandSection):
return kwargs
def _get_requirements_manager(self, os_override=None, base_interpreter=None):
def _get_requirements_manager(self, os_override=None, base_interpreter=None, requirement_substitutions=None):
requirements_manager = RequirementsManager(
self._session, base_interpreter=base_interpreter
)
for requirement_cls in self._requirement_substitutions:
requirement_substitutions = requirement_substitutions or self._requirement_substitutions
for requirement_cls in requirement_substitutions:
if os_override and issubclass(requirement_cls, PytorchRequirement):
requirement_cls = partial(requirement_cls, os_override=os_override)
requirements_manager.register(requirement_cls)
@ -1468,7 +1469,19 @@ class Worker(ServiceCommandSection):
directory, vcs, repo_info = self.get_repo_info(execution, current_task, venv_folder.as_posix())
if not is_cached:
if is_cached:
# reinstalling git / local packages
package_api = copy(self.package_api)
package_api.requirements_manager = self._get_requirements_manager(
base_interpreter=package_api.requirements_manager.get_interpreter(),
requirement_substitutions=[OnlyExternalRequirements]
)
# make sure we run the handlers
cached_requirements = \
{k: package_api.requirements_manager.replace(requirements[k] or '')
for k in requirements}
package_api.load_requirements(cached_requirements)
else:
self.install_requirements(
execution,
repo_info,
@ -1477,6 +1490,7 @@ class Worker(ServiceCommandSection):
cwd=vcs.location if vcs and vcs.location else directory,
package_api=self.global_package_api if install_globally else None,
)
freeze = self.freeze_task_environment(
task_id=task_id, requirements_manager=requirements_manager, update_requirements=False)
script_dir = directory
@ -1721,7 +1735,20 @@ class Worker(ServiceCommandSection):
print("\n")
if not is_cached and not standalone_mode:
if is_cached and not standalone_mode:
# reinstalling git / local packages
package_api = copy(self.package_api)
package_api.requirements_manager = self._get_requirements_manager(
base_interpreter=package_api.requirements_manager.get_interpreter(),
requirement_substitutions=[OnlyExternalRequirements]
)
# make sure we run the handlers
cached_requirements = \
{k: package_api.requirements_manager.replace(requirements[k] or '')
for k in requirements}
package_api.load_requirements(cached_requirements)
elif not is_cached and not standalone_mode:
self.install_requirements(
execution,
repo_info,
@ -2376,6 +2403,7 @@ class Worker(ServiceCommandSection):
if not standalone_mode:
rm_tree(normalize_path(venv_dir, WORKING_REPOSITORY_DIR))
package_manager_params = dict(
session=self._session,
python=executable_version_suffix if self.is_conda else executable_name,

View File

@ -17,6 +17,15 @@ class ExternalRequirements(SimpleSubstitution):
self.post_install_req_lookup = OrderedDict()
def match(self, req):
# match local folder building:
# noinspection PyBroadException
try:
if not req.name and req.req and not req.req.editable and not req.req.vcs and \
req.req.line and not req.req.line.strip().split('#')[0].lower().endswith('.whl'):
return True
except Exception:
pass
# match both editable or code or unparsed
if not (not req.name or req.req and (req.req.editable or req.req.vcs)):
return False
@ -104,3 +113,20 @@ class ExternalRequirements(SimpleSubstitution):
list_of_requirements[k] += [self.post_install_req_lookup.get(r, '')
for r in self.post_install_req_lookup.keys() if r in original_requirements]
return list_of_requirements
class OnlyExternalRequirements(ExternalRequirements):
    """
    Inverted variant of ExternalRequirements.

    This handler claims every requirement the parent handler would NOT match
    and replaces it with an empty string; requirements the parent does match
    are not touched by this handler.
    NOTE(review): presumably used to re-install only git / local packages into
    an already cached environment - confirm against the caller.
    """

    def match(self, req):
        """
        Return True when the parent ExternalRequirements does not match `req`.
        """
        parent_matches = super(OnlyExternalRequirements, self).match(req)
        return not parent_matches

    def replace(self, req):
        """
        Drop the requirement: nothing is stored and an empty Text is returned.
        """
        # an empty line marks the package as skipped
        skipped = Text('')
        return skipped

View File

@ -447,6 +447,7 @@ class RequirementsManager(object):
'cu'+agent['cuda_version'] if self.found_cuda else 'cpu')
self.translator = RequirementsTranslator(session, interpreter=base_interpreter,
cache_dir=pip_cache_dir.as_posix())
self._base_interpreter = base_interpreter
def register(self, cls): # type: (Type[RequirementSubstitution]) -> None
self.handlers.append(cls(self._session))
@ -530,6 +531,9 @@ class RequirementsManager(object):
pass
return requirements
def get_interpreter(self):
    """
    Return the value stored in this object's `_base_interpreter` attribute.
    """
    interpreter = self._base_interpreter
    return interpreter
@staticmethod
def get_cuda_version(config): # type: (ConfigTree) -> (Text, Text)
# we assume os.environ already updated the config['agent.cuda_version'] & config['agent.cudnn_version']
@ -605,4 +609,3 @@ class RequirementsManager(object):
return (normalize_cuda_version(cuda_version or 0),
normalize_cuda_version(cudnn_version or 0))