diff --git a/trains_agent/helper/package/pytorch.py b/trains_agent/helper/package/pytorch.py index f492cdc..addc606 100644 --- a/trains_agent/helper/package/pytorch.py +++ b/trains_agent/helper/package/pytorch.py @@ -435,6 +435,37 @@ class PytorchRequirement(SimpleSubstitution): self.log.debug('Replacing requirement "%s" with %r', req, result) return result + def replace_back(self, list_of_requirements): # type: (Dict) -> Dict + """ + :param list_of_requirements: {'pip': ['a==1.0', ]} + :return: {'pip': ['a==1.0', ]} + """ + if not self._original_req: + return list_of_requirements + try: + for k, lines in list_of_requirements.items(): + # k is either pip/conda + if k not in ('pip', 'conda'): + continue + for i, line in enumerate(lines): + if not line or line.lstrip().startswith('#'): + continue + parts = [p for p in re.split('\s|=|\.|<|>|~|!|@|#', line) if p] + if not parts: + continue + for req, new_req in self._original_req: + if req.req.name == parts[0]: + # support for pip >= 20.1 + if '@' in line: + lines[i] = '{} # {}'.format(str(req), str(new_req)) + else: + lines[i] = '{} # {}'.format(line, str(new_req)) + break + except: + pass + + return list_of_requirements + MAP = { "windows": { "cuda100": {