Replace torch version on pre-installed local file

This commit is contained in:
allegroai 2020-10-04 19:40:39 +03:00
parent 8135a6facf
commit 5640489f57

View File

@ -281,7 +281,7 @@ class PytorchRequirement(SimpleSubstitution):
req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version'])) req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version']))
# package already installed, do nothing # package already installed, do nothing
req.specs = [('==', str(installed_torch[0]['version']))] req.specs = [('==', str(installed_torch[0]['version']))]
return str(req), True return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True
except Exception: except Exception:
pass pass
@ -463,6 +463,10 @@ class PytorchRequirement(SimpleSubstitution):
if '@' in line: if '@' in line:
# skip if we have nothing to add # skip if we have nothing to add
if str(req).strip() != str(new_req).strip(): if str(req).strip() != str(new_req).strip():
# if this is local file and use the version detection
if req.local_file:
lines[i] = '{}'.format(str(new_req))
else:
lines[i] = '{} # {}'.format(str(req), str(new_req)) lines[i] = '{} # {}'.format(str(req), str(new_req))
else: else:
lines[i] = '{} # {}'.format(line, str(new_req)) lines[i] = '{} # {}'.format(line, str(new_req))