Fix cache to take cuda version into account

This commit is contained in:
allegroai 2021-02-11 14:47:05 +02:00
parent 410cc8c7be
commit b22d926d94
2 changed files with 20 additions and 9 deletions

View File

@ -2057,6 +2057,7 @@ class Worker(ServiceCommandSection):
requirements=[freeze, previous_reqs],
docker_cmd=execution_info.docker_cmd if execution_info else None,
python_version=getattr(self.package_api, 'python', ''),
cuda_version=self._session.config.get("agent.cuda_version"),
source_folder=add_venv_folder_cache,
exclude_sub_folders=['task_repository', 'code'])
@ -2402,6 +2403,7 @@ class Worker(ServiceCommandSection):
requirements=cached_requirements,
docker_cmd=execution_info.docker_cmd if execution_info else None,
python_version=package_manager_params['python'],
cuda_version=self._session.config.get("agent.cuda_version"),
destination_folder=Path(venv_dir)
):
print('::: Using Cached environment {} :::'.format(self.package_api.get_last_used_entry_cache()))

View File

@ -165,8 +165,8 @@ class PackageManager(object):
def get_pip_version(cls):
return cls._pip_version or ''
def get_cached_venv(self, requirements, docker_cmd, python_version, destination_folder):
# type: (Dict, Optional[Union[dict, str]], Optional[str], Path) -> Optional[Path]
def get_cached_venv(self, requirements, docker_cmd, python_version, cuda_version, destination_folder):
# type: (Dict, Optional[Union[dict, str]], Optional[str], Optional[str], Path) -> Optional[Path]
"""
Copy a cached copy of the venv (based on the requirements) into destination_folder.
Return None if failed or cached entry does not exist
@ -174,17 +174,25 @@ class PackageManager(object):
if not self._get_cache_manager():
return None
keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version)
keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version, cuda_version)
return self._get_cache_manager().copy_cached_entry(keys, destination_folder)
def add_cached_venv(self, requirements, docker_cmd, python_version, source_folder, exclude_sub_folders=None):
# type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str], Path, Optional[List[str]]) -> ()
def add_cached_venv(
self,
requirements, # type: Union[Dict, List[Dict]]
docker_cmd, # type: Optional[Union[dict, str]]
python_version, # type: Optional[str]
cuda_version, # type: Optional[str]
source_folder, # type: Path
exclude_sub_folders=None # type: Optional[List[str]]
):
# type: (...) -> ()
"""
Copy the local venv folder into the venv cache (keys are based on the requirements+python+docker).
"""
if not self._get_cache_manager():
return
keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version)
keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version, cuda_version)
return self._get_cache_manager().add_entry(
keys=keys, source_folder=source_folder, exclude_sub_folders=exclude_sub_folders)
@ -204,8 +212,8 @@ class PackageManager(object):
return self._get_cache_manager().get_last_copied_entry()
@classmethod
def _generate_reqs_hash_keys(cls, requirements_list, docker_cmd, python_version):
# type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str]) -> List[str]
def _generate_reqs_hash_keys(cls, requirements_list, docker_cmd, python_version, cuda_version):
# type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str], Optional[str]) -> List[str]
requirements_list = requirements_list or dict()
if not isinstance(requirements_list, (list, tuple)):
requirements_list = [requirements_list]
@ -231,9 +239,10 @@ class PackageManager(object):
if p.strip(strip_chars) and not p.strip(strip_chars).startswith('#')])
if not pip_reqs and not conda_reqs:
continue
hash_text = '{class_type}\n{docker_cmd}\n{python_version}\n{pip_reqs}\n{conda_reqs}'.format(
hash_text = '{class_type}\n{docker_cmd}\n{cuda_ver}\n{python_version}\n{pip_reqs}\n{conda_reqs}'.format(
class_type=str(cls),
docker_cmd=str(docker_cmd or ''),
cuda_ver=str(cuda_version or ''),
python_version=str(python_version or ''),
pip_reqs=str(pip_reqs or ''),
conda_reqs=str(conda_reqs or ''),