From 5640489f57a53d3fdf5ff12664b18bb25e576326 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 4 Oct 2020 19:40:39 +0300 Subject: [PATCH] Replace torch version on pre-installed local file --- trains_agent/helper/package/pytorch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index 22ad91b..7150b25 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -281,7 +281,7 @@ class PytorchRequirement(SimpleSubstitution): req.name, req.specs[0] if req.specs else 'unspecified', installed_torch[0]['version'])) # package already installed, do nothing req.specs = [('==', str(installed_torch[0]['version']))] - return str(req), True + return '{} {} {}'.format(req.name, req.specs[0][0], req.specs[0][1]), True except Exception: pass @@ -463,7 +463,11 @@ class PytorchRequirement(SimpleSubstitution): if '@' in line: # skip if we have nothing to add if str(req).strip() != str(new_req).strip(): - lines[i] = '{} # {}'.format(str(req), str(new_req)) + # if this is local file and use the version detection + if req.local_file: + lines[i] = '{}'.format(str(new_req)) + else: + lines[i] = '{} # {}'.format(str(req), str(new_req)) else: lines[i] = '{} # {}'.format(line, str(new_req)) break