mirror of
https://github.com/clearml/clearml-agent
synced 2025-03-03 18:52:22 +00:00
Add new pytorch no resolver mode and CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE to change resolver on a Task basis, now supports "pip", "direct", "none"
This commit is contained in:
parent
fb639afcb9
commit
d16825029d
@ -80,14 +80,15 @@
|
|||||||
# additional artifact repositories to use when installing python packages
|
# additional artifact repositories to use when installing python packages
|
||||||
# extra_index_url: ["https://allegroai.jfrog.io/clearml/api/pypi/public/simple"]
|
# extra_index_url: ["https://allegroai.jfrog.io/clearml/api/pypi/public/simple"]
|
||||||
|
|
||||||
# control the pytorch wheel resolving algorithm, options are: "pip", "direct"
|
# control the pytorch wheel resolving algorithm, options are: "pip", "direct", "none"
|
||||||
|
# Override with environment variable CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE
|
||||||
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct
|
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct
|
||||||
# extra-index-url, based on pytorch.org tables
|
# extra-index-url, based on pytorch.org tables
|
||||||
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
|
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
|
||||||
# and matching the automatically detected cuda version with the required pytorch wheel.
|
# and matching the automatically detected cuda version with the required pytorch wheel.
|
||||||
# if the exact cuda version is not found for the required pytorch wheel, it will try
|
# if the exact cuda version is not found for the required pytorch wheel, it will try
|
||||||
# a lower cuda version until a match is found
|
# a lower cuda version until a match is found
|
||||||
#
|
# "none": No resolver used, install pytorch like any other package
|
||||||
# pytorch_resolve: "pip"
|
# pytorch_resolve: "pip"
|
||||||
|
|
||||||
# additional conda channels to use when installing with conda package manager
|
# additional conda channels to use when installing with conda package manager
|
||||||
|
@ -238,6 +238,8 @@ ENV_CUSTOM_BUILD_SCRIPT = EnvironmentConfig("CLEARML_AGENT_CUSTOM_BUILD_SCRIPT")
|
|||||||
standard flow.
|
standard flow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ENV_PACKAGE_PYTORCH_RESOLVE = EnvironmentConfig("CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE")
|
||||||
|
|
||||||
|
|
||||||
class FileBuffering(IntEnum):
|
class FileBuffering(IntEnum):
|
||||||
"""
|
"""
|
||||||
|
@ -16,6 +16,7 @@ import six
|
|||||||
from .requirements import (
|
from .requirements import (
|
||||||
SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement,
|
SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement,
|
||||||
compare_version_rules, )
|
compare_version_rules, )
|
||||||
|
from ...definitions import ENV_PACKAGE_PYTORCH_RESOLVE
|
||||||
from ...external.requirements_parser.requirement import Requirement
|
from ...external.requirements_parser.requirement import Requirement
|
||||||
|
|
||||||
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
||||||
@ -174,6 +175,7 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/'
|
extra_index_url_template = 'https://download.pytorch.org/whl/cu{}/'
|
||||||
nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/'
|
nightly_extra_index_url_template = 'https://download.pytorch.org/whl/nightly/cu{}/'
|
||||||
torch_index_url_lookup = {}
|
torch_index_url_lookup = {}
|
||||||
|
resolver_types = ("pip", "direct", "none")
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os_name = kwargs.pop("os_override", None)
|
os_name = kwargs.pop("os_override", None)
|
||||||
@ -208,6 +210,13 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
if self.config.get("agent.package_manager.torch_url_template", None):
|
if self.config.get("agent.package_manager.torch_url_template", None):
|
||||||
PytorchWheel.url_template = \
|
PytorchWheel.url_template = \
|
||||||
self.config.get("agent.package_manager.torch_url_template", None)
|
self.config.get("agent.package_manager.torch_url_template", None)
|
||||||
|
self.resolve_algorithm = str(
|
||||||
|
ENV_PACKAGE_PYTORCH_RESOLVE.get() or
|
||||||
|
self.config.get("agent.package_manager.pytorch_resolve", "pip")).lower()
|
||||||
|
if self.resolve_algorithm not in self.resolver_types:
|
||||||
|
print("WARNING: agent.package_manager.pytorch_resolve=={} not in {} reverting to '{}'".format(
|
||||||
|
self.resolve_algorithm, self.resolver_types, self.resolver_types[0]))
|
||||||
|
self.resolve_algorithm = self.resolver_types[0]
|
||||||
|
|
||||||
def _init_python_ver_cuda_ver(self):
|
def _init_python_ver_cuda_ver(self):
|
||||||
if self.cuda is None:
|
if self.cuda is None:
|
||||||
@ -261,6 +270,10 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def match(self, req):
|
def match(self, req):
|
||||||
|
if self.resolve_algorithm == "none":
|
||||||
|
# skipping resolver
|
||||||
|
return False
|
||||||
|
|
||||||
return req.name in self.packages
|
return req.name in self.packages
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -347,8 +360,10 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
from pip._internal.commands.show import search_packages_info
|
from pip._internal.commands.show import search_packages_info
|
||||||
installed_torch = list(search_packages_info([req.name]))
|
installed_torch = list(search_packages_info([req.name]))
|
||||||
# notice the comparison order, the first part will make sure we have a valid installed package
|
# notice the comparison order, the first part will make sure we have a valid installed package
|
||||||
installed_torch_version = (getattr(installed_torch[0], 'version', None) or installed_torch[0]['version']) \
|
installed_torch_version = \
|
||||||
if installed_torch else None
|
(getattr(installed_torch[0], 'version', None) or
|
||||||
|
installed_torch[0]['version']) if installed_torch else None
|
||||||
|
|
||||||
if installed_torch and installed_torch_version and \
|
if installed_torch and installed_torch_version and \
|
||||||
req.compare_version(installed_torch_version):
|
req.compare_version(installed_torch_version):
|
||||||
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
|
print('PyTorch: requested "{}" version {}, using pre-installed version {}'.format(
|
||||||
@ -356,6 +371,7 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
# package already installed, do nothing
|
# package already installed, do nothing
|
||||||
req.specs = [('==', str(installed_torch_version))]
|
req.specs = [('==', str(installed_torch_version))]
|
||||||
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
|
return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -480,8 +496,11 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
# we first try to resolve things ourselves because pytorch pip is not always picking the correct
|
# we first try to resolve things ourselves because pytorch pip is not always picking the correct
|
||||||
# versions from their pip repository
|
# versions from their pip repository
|
||||||
|
|
||||||
resolve_algorithm = str(self.config.get("agent.package_manager.pytorch_resolve", "pip")).lower()
|
resolve_algorithm = self.resolve_algorithm
|
||||||
if resolve_algorithm == "direct":
|
if resolve_algorithm == "none":
|
||||||
|
# skipping resolver
|
||||||
|
return None
|
||||||
|
elif resolve_algorithm == "direct":
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
new_req = self._replace(req)
|
new_req = self._replace(req)
|
||||||
@ -489,8 +508,8 @@ class PytorchRequirement(SimpleSubstitution):
|
|||||||
self._original_req.append((req, new_req))
|
self._original_req.append((req, new_req))
|
||||||
return new_req
|
return new_req
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
print("Warning: Failed resolving using `pytorch_resolve=direct` reverting to `pytorch_resolve=pip`")
|
||||||
elif resolve_algorithm not in ("direct", "pip"):
|
elif resolve_algorithm not in self.resolver_types:
|
||||||
print("Warning: `agent.package_manager.pytorch_resolve={}` "
|
print("Warning: `agent.package_manager.pytorch_resolve={}` "
|
||||||
"unrecognized, default to `pip`".format(resolve_algorithm))
|
"unrecognized, default to `pip`".format(resolve_algorithm))
|
||||||
|
|
||||||
|
@ -96,14 +96,15 @@ agent {
|
|||||||
# additional flags to use when calling pip install, example: ["--use-deprecated=legacy-resolver", ]
|
# additional flags to use when calling pip install, example: ["--use-deprecated=legacy-resolver", ]
|
||||||
# extra_pip_install_flags: []
|
# extra_pip_install_flags: []
|
||||||
|
|
||||||
# control the pytorch wheel resolving algorithm, options are: "pip", "direct"
|
# control the pytorch wheel resolving algorithm, options are: "pip", "direct", "none"
|
||||||
|
# Override with environment variable CLEARML_AGENT_PACKAGE_PYTORCH_RESOLVE
|
||||||
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct
|
# "pip" (default): would automatically detect the cuda version, and supply pip with the correct
|
||||||
# extra-index-url, based on pytorch.org tables
|
# extra-index-url, based on pytorch.org tables
|
||||||
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
|
# "direct": would resolve a direct link to the pytorch wheel by parsing the pytorch.org pip repository
|
||||||
# and matching the automatically detected cuda version with the required pytorch wheel.
|
# and matching the automatically detected cuda version with the required pytorch wheel.
|
||||||
# if the exact cuda version is not found for the required pytorch wheel, it will try
|
# if the exact cuda version is not found for the required pytorch wheel, it will try
|
||||||
# a lower cuda version until a match is found
|
# a lower cuda version until a match is found
|
||||||
#
|
# "none": No resolver used, install pytorch like any other package
|
||||||
# pytorch_resolve: "pip"
|
# pytorch_resolve: "pip"
|
||||||
|
|
||||||
# additional conda channels to use when installing with conda package manager
|
# additional conda channels to use when installing with conda package manager
|
||||||
|
Loading…
Reference in New Issue
Block a user