mirror of
https://github.com/clearml/clearml-agent
synced 2025-02-07 13:26:08 +00:00
Fix broken pytorch setuptools incompatibility (force setuptools < 59 if torch is below 1.11)
This commit is contained in:
parent
5f77cad5ac
commit
6c7a639673
@ -7,13 +7,14 @@ from furl import furl
|
||||
import urllib.parse
|
||||
from operator import itemgetter
|
||||
from html.parser import HTMLParser
|
||||
from typing import Text
|
||||
from typing import Text, Optional
|
||||
|
||||
import attr
|
||||
import requests
|
||||
|
||||
import six
|
||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion
|
||||
from .requirements import SimpleSubstitution, FatalSpecsResolutionError, SimpleVersion, MarkerRequirement
|
||||
from ...external.requirements_parser.requirement import Requirement
|
||||
|
||||
OS_TO_WHEEL_NAME = {"linux": "linux_x86_64", "windows": "win_amd64"}
|
||||
|
||||
@ -179,6 +180,7 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
self.python_version_string = None
|
||||
self.python_major_minor_str = None
|
||||
self.python = None
|
||||
self._fix_setuptools = None
|
||||
self.exceptions = []
|
||||
self._original_req = []
|
||||
|
||||
@ -366,6 +368,10 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
else:
|
||||
print('Trying PyTorch CUDA version {} support'.format(torch_url_key))
|
||||
|
||||
# fix broken pytorch setuptools incompatibility
|
||||
if closest_matched_version and SimpleVersion.compare_versions(closest_matched_version, "<", "1.11.0"):
|
||||
self._fix_setuptools = "setuptools < 59"
|
||||
|
||||
if not url:
|
||||
url = PytorchWheel(
|
||||
torch_version=fix_version(version),
|
||||
@ -528,6 +534,16 @@ class PytorchRequirement(SimpleSubstitution):
|
||||
|
||||
return list_of_requirements
|
||||
|
||||
def post_scan_add_req(self): # type: () -> Optional[MarkerRequirement]
|
||||
"""
|
||||
Allows the RequirementSubstitution to add an extra line/requirements after
|
||||
the initial requirements scan is completed.
|
||||
Called only once per requirements.txt object
|
||||
"""
|
||||
if self._fix_setuptools:
|
||||
return MarkerRequirement(Requirement.parse(self._fix_setuptools))
|
||||
return None
|
||||
|
||||
MAP = {
|
||||
"windows": {
|
||||
"cuda100": {
|
||||
|
@ -628,10 +628,23 @@ class RequirementsManager(object):
|
||||
|
||||
result = list(result)
|
||||
# add post scan add requirements call back
|
||||
double_req_set = None
|
||||
for h in self.handlers:
|
||||
req = h.post_scan_add_req()
|
||||
if req:
|
||||
result.append(req.tostr())
|
||||
reqs = h.post_scan_add_req()
|
||||
if reqs:
|
||||
if double_req_set is None:
|
||||
def safe_parse_name(line):
|
||||
try:
|
||||
return Requirement.parse(line).name
|
||||
except: # noqa
|
||||
return None
|
||||
double_req_set = set([safe_parse_name(r) for r in result if r])
|
||||
|
||||
for r in (reqs if isinstance(reqs, (tuple, list)) else [reqs]):
|
||||
if r and (not r.name or r.name not in double_req_set):
|
||||
result.append(r.tostr())
|
||||
elif r:
|
||||
print("SKIPPING additional auto installed package: \"{}\"".format(r))
|
||||
|
||||
return join_lines(result)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user