diff --git a/clearml_agent/commands/worker.py b/clearml_agent/commands/worker.py index 566a030..d6f00d5 100644 --- a/clearml_agent/commands/worker.py +++ b/clearml_agent/commands/worker.py @@ -2057,6 +2057,7 @@ class Worker(ServiceCommandSection): requirements=[freeze, previous_reqs], docker_cmd=execution_info.docker_cmd if execution_info else None, python_version=getattr(self.package_api, 'python', ''), + cuda_version=self._session.config.get("agent.cuda_version"), source_folder=add_venv_folder_cache, exclude_sub_folders=['task_repository', 'code']) @@ -2402,6 +2403,7 @@ class Worker(ServiceCommandSection): requirements=cached_requirements, docker_cmd=execution_info.docker_cmd if execution_info else None, python_version=package_manager_params['python'], + cuda_version=self._session.config.get("agent.cuda_version"), destination_folder=Path(venv_dir) ): print('::: Using Cached environment {} :::'.format(self.package_api.get_last_used_entry_cache())) diff --git a/clearml_agent/helper/package/base.py b/clearml_agent/helper/package/base.py index 737465f..32d7e46 100644 --- a/clearml_agent/helper/package/base.py +++ b/clearml_agent/helper/package/base.py @@ -165,8 +165,8 @@ class PackageManager(object): def get_pip_version(cls): return cls._pip_version or '' - def get_cached_venv(self, requirements, docker_cmd, python_version, destination_folder): - # type: (Dict, Optional[Union[dict, str]], Optional[str], Path) -> Optional[Path] + def get_cached_venv(self, requirements, docker_cmd, python_version, cuda_version, destination_folder): + # type: (Dict, Optional[Union[dict, str]], Optional[str], Optional[str], Path) -> Optional[Path] """ Copy a cached copy of the venv (based on the requirements) into destination_folder. Return None if failed or cached entry does not exist @@ -174,17 +174,25 @@ class PackageManager(object): if not self._get_cache_manager(): return None - keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version) + keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version, cuda_version) return self._get_cache_manager().copy_cached_entry(keys, destination_folder) - def add_cached_venv(self, requirements, docker_cmd, python_version, source_folder, exclude_sub_folders=None): - # type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str], Path, Optional[List[str]]) -> () + def add_cached_venv( + self, + requirements, # type: Union[Dict, List[Dict]] + docker_cmd, # type: Optional[Union[dict, str]] + python_version, # type: Optional[str] + cuda_version, # type: Optional[str] + source_folder, # type: Path + exclude_sub_folders=None # type: Optional[List[str]] + ): + # type: (...) -> () """ Copy the local venv folder into the venv cache (keys are based on the requirements+python+docker). """ if not self._get_cache_manager(): return - keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version) + keys = self._generate_reqs_hash_keys(requirements, docker_cmd, python_version, cuda_version) return self._get_cache_manager().add_entry( keys=keys, source_folder=source_folder, exclude_sub_folders=exclude_sub_folders) @@ -204,8 +212,8 @@ class PackageManager(object): return self._get_cache_manager().get_last_copied_entry() @classmethod - def _generate_reqs_hash_keys(cls, requirements_list, docker_cmd, python_version): - # type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str]) -> List[str] + def _generate_reqs_hash_keys(cls, requirements_list, docker_cmd, python_version, cuda_version): + # type: (Union[Dict, List[Dict]], Optional[Union[dict, str]], Optional[str], Optional[str]) -> List[str] requirements_list = requirements_list or dict() if not isinstance(requirements_list, (list, tuple)): requirements_list = [requirements_list] @@ -231,9 +239,10 @@ class PackageManager(object): if p.strip(strip_chars) and not p.strip(strip_chars).startswith('#')]) if not pip_reqs and not conda_reqs: continue - hash_text = '{class_type}\n{docker_cmd}\n{python_version}\n{pip_reqs}\n{conda_reqs}'.format( + hash_text = '{class_type}\n{docker_cmd}\n{cuda_ver}\n{python_version}\n{pip_reqs}\n{conda_reqs}'.format( class_type=str(cls), docker_cmd=str(docker_cmd or ''), + cuda_ver=str(cuda_version or ''), python_version=str(python_version or ''), pip_reqs=str(pip_reqs or ''), conda_reqs=str(conda_reqs or ''),