diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 22ad91b..7150b25 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -281,7 +281,7 @@ class PytorchRequirement(SimpleSubstitution): req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version'])) # package already installed, do nothing req.specs = [('==', str(installed_torch[0]['version']))] - return str(req), True + return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True except Exception: pass @@ -463,7 +463,11 @@ class PytorchRequirement(SimpleSubstitution): if '@' in line: # skip if we have nothing to add if str(req).strip() != str(new_req).strip(): - lines[i] = '{} # {}'.format(str(req), str(new_req)) + # if this is local file and use the version detection + if req.local_file: + lines[i] = '{}'.format(str(new_req)) + else: + lines[i] = '{} # {}'.format(str(req), str(new_req)) else: lines[i] = '{} # {}'.format(line, str(new_req)) break