Add new pytorch no resolver mode and CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE to change resolver on a Task basis, now supports "pip", "direct", "none"

This commit is contained in:
allegroai 2023-09-02 17:45:10 +03:00
parent fb639afcb9
commit d16825029d
4 changed files with 33 additions and 10 deletions

View File

@ -80,14 +80,15 @@
# additional artifact repositories to use when installing python packages # additional artifact repositories to use when installing python packages
# extra_index_url: ["https://allegroai.jfrog.io/clearml/api/pypi/public/simple"] # extra_index_url: ["https://allegroai.jfrog.io/clearml/api/pypi/public/simple"]
# control the pytorch wheel resolving algorithm, options are: "pip", "direct" # control the pytorch wheel resolving algorithm, options are: "pip", "direct", "none"
# Override with environment variable CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct # "pip" (default): would automatically detect the cuda version, and supply pip with the correct
# extra-index-url, based on pytorch.org tables # extra-index-url, based on pytorch.org tables
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository # "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
# and matching the automatically detected cuda version with the required pytorch wheel. # and matching the automatically detected cuda version with the required pytorch wheel.
# if the exact cuda version is not found for the required pytorch wheel, it will try # if the exact cuda version is not found for the required pytorch wheel, it will try
# a lower cuda version until a match is found # a lower cuda version until a match is found
# # "none": No resolver used, install pytorch like any other package
# pytorch_resolve: "pip" # pytorch_resolve: "pip"
# additional conda channels to use when installing with conda package manager # additional conda channels to use when installing with conda package manager

View File

@ -238,6 +238,8 @@ ENV_CUSTOM_BUILD_SCRIPT = EnvironmentConfig("CLEARML_AGENT_CUSTOM_BUILD_SCRIPT")
standard flow. standard flow.
""" """
ENV_PACKAGE_PYTORCH_RESOLVE = EnvironmentConfig("CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE")
class FileBuffering(IntEnum): class FileBuffering(IntEnum):
""" """

View File

@ -16,6 +16,7 @@ import six
from .requirements import ( from .requirements import (
SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement, SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement,
compare_version_rules, ) compare_version_rules, )
from ...definitions import ENV_PACKAGE_PYTORCH_RESOLVE
from ...external.requirements_parser.requirement import Requirement from ...external.requirements_parser.requirement import Requirement
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"} OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
@ -174,6 +175,7 @@ class PytorchRequirement(SimpleSubstitution):
extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/' extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/'
nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/' nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/'
torch_index_url_lookup = {} torch_index_url_lookup = {}
resolver_types = ("pip", "direct", "none")
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
os_name = kwargs.pop("os_override", None) os_name = kwargs.pop("os_override", None)
@ -208,6 +210,13 @@ class PytorchRequirement(SimpleSubstitution):
if self.config.get("agent.package_manager.torch_url_template", None): if self.config.get("agent.package_manager.torch_url_template", None):
PytorchWheel.url_template = \ PytorchWheel.url_template = \
self.config.get("agent.package_manager.torch_url_template", None) self.config.get("agent.package_manager.torch_url_template", None)
self.resolve_algorithm = str(
ENV_PACKAGE_PYTORCH_RESOLVE.get() or
self.config.get("agent.package_manager.pytorch_resolve", "pip")).lower()
if self.resolve_algorithm not in self.resolver_types:
print("WARNING: agent.package_manager.pytorch_resolve=={} not in {} reverting to '{}'".format(
self.resolve_algorithm, self.resolver_types, self.resolver_types[0]))
self.resolve_algorithm = self.resolver_types[0]
def _init_python_ver_cuda_ver(self): def _init_python_ver_cuda_ver(self):
if self.cuda is None: if self.cuda is None:
@ -261,6 +270,10 @@ class PytorchRequirement(SimpleSubstitution):
) )
def match(self, req): def match(self, req):
if self.resolve_algorithm == "none":
# skipping resolver
return False
return req.name in self.packages return req.name in self.packages
@staticmethod @staticmethod
@ -347,8 +360,10 @@ class PytorchRequirement(SimpleSubstitution):
from pip._internal.commands.show import search_packages_info from pip._internal.commands.show import search_packages_info
installed_torch = list(search_packages_info([req.name])) installed_torch = list(search_packages_info([req.name]))
# notice the comparison order, the first part will make sure we have a valid installed package # notice the comparison order, the first part will make sure we have a valid installed package
installed_torch_version = (getattr(installed_torch[0], 'version', None) or installed_torch[0]['version']) \ installed_torch_version = \
if installed_torch else None (getattr(installed_torch[0], 'version', None) or
installed_torch[0]['version']) if installed_torch else None
if installed_torch and installed_torch_version and \ if installed_torch and installed_torch_version and \
req.compare_version(installed_torch_version): req.compare_version(installed_torch_version):
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format( print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
@ -356,6 +371,7 @@ class PytorchRequirement(SimpleSubstitution):
# package already installed, do nothing # package already installed, do nothing
req.specs = [('==', str(installed_torch_version))] req.specs = [('==', str(installed_torch_version))]
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
except Exception: except Exception:
pass pass
@ -480,8 +496,11 @@ class PytorchRequirement(SimpleSubstitution):
# we first try to resolve things ourselves because pytorch pip is not always picking the correct # we first try to resolve things ourselves because pytorch pip is not always picking the correct
# versions from their pip repository # versions from their pip repository
resolve_algorithm = str(self.config.get("agent.package_manager.pytorch_resolve", "pip")).lower() resolve_algorithm = self.resolve_algorithm
if resolve_algorithm == "direct": if resolve_algorithm == "none":
# skipping resolver
return None
elif resolve_algorithm == "direct":
# noinspection PyBroadException # noinspection PyBroadException
try: try:
new_req = self._replace(req) new_req = self._replace(req)
@ -489,8 +508,8 @@ class PytorchRequirement(SimpleSubstitution):
self._original_req.append((req, new_req)) self._original_req.append((req, new_req))
return new_req return new_req
except Exception: except Exception:
pass print("Warning: Failed resolving using `pytorch_resolve=direct` reverting to `pytorch_resolve=pip`")
elif resolve_algorithm not in ("direct", "pip"): elif resolve_algorithm not in self.resolver_types:
print("Warning: `agent.package_manager.pytorch_resolve={}` " print("Warning: `agent.package_manager.pytorch_resolve={}` "
"unrecognized, default to `pip`".format(resolve_algorithm)) "unrecognized, default to `pip`".format(resolve_algorithm))

View File

@ -96,14 +96,15 @@ agent {
# additional flags to use when calling pip install, example: ["--use-deprecated=legacy-resolver", ] # additional flags to use when calling pip install, example: ["--use-deprecated=legacy-resolver", ]
# extra_pip_install_flags: [] # extra_pip_install_flags: []
# control the pytorch wheel resolving algorithm, options are: "pip", "direct" # control the pytorch wheel resolving algorithm, options are: "pip", "direct", "none"
# Override with environment variable CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct # "pip" (default): would automatically detect the cuda version, and supply pip with the correct
# extra-index-url, based on pytorch.org tables # extra-index-url, based on pytorch.org tables
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository # "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
# and matching the automatically detected cuda version with the required pytorch wheel. # and matching the automatically detected cuda version with the required pytorch wheel.
# if the exact cuda version is not found for the required pytorch wheel, it will try # if the exact cuda version is not found for the required pytorch wheel, it will try
# a lower cuda version until a match is found # a lower cuda version until a match is found
# # "none": No resolver used, install pytorch like any other package
# pytorch_resolve: "pip" # pytorch_resolve: "pip"
# additional conda channels to use when installing with conda package manager # additional conda channels to use when installing with conda package manager