diff --git a/trains_agent/helper/package/conda_api.py b/trains_agent/helper/package/conda_api.py index 00ffab4..dc135a9 100644 --- a/trains_agent/helper/package/conda_api.py +++ b/trains_agent/helper/package/conda_api.py @@ -263,37 +263,59 @@ class CondaAPI(PackageManager): conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda'] conda_supported_req_names = [] for r in conda_supported_req: - marker = list(parse(r)) - if marker: - m = MarkerRequirement(marker[0]) - conda_supported_req_names.append(m.name.lower()) - if m.req.name.lower() == 'matplotlib': - has_matplotlib = True - elif m.req.name.lower().startswith('torch'): - has_torch = True + try: + marker = list(parse(r)) + except: + marker = None + if not marker: + continue - if m.req.name.lower() in ('torch', 'pytorch'): - has_torch = True - m.req.name = 'pytorch' + m = MarkerRequirement(marker[0]) + conda_supported_req_names.append(m.name.lower()) + if m.req.name.lower() == 'matplotlib': + has_matplotlib = True + elif m.req.name.lower().startswith('torch'): + has_torch = True - if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'): - has_torch = True - m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow' + if m.req.name.lower() in ('torch', 'pytorch'): + has_torch = True + m.req.name = 'pytorch' - # conda always likes - not _ - m.req.name = m.req.name.replace('_', '-') + if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu', 'tensorflow'): + has_torch = True + m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow' - reqs.append(m) + reqs.append(m) pip_requirements = [] # if we have a conda list, the rest should be installed with pip, if requirements.get('conda', None) is not None: for r in requirements['pip']: - marker = list(parse(r)) - if marker: - m = MarkerRequirement(marker[0]) - if m.name.lower() not in conda_supported_req_names: - pip_requirements.append(m) + try: + marker = list(parse(r)) + except: + marker = None + if not marker: + continue + + m = MarkerRequirement(marker[0]) + m_name = m.name.lower() + if m_name in conda_supported_req_names: + # this package is in the conda list, + # make sure that if we changed version and we match it in conda + conda_supported_req_names.remove(m_name) + for cr in reqs: + if m_name == cr.name.lower(): + # match versions + cr.specs = m.specs + break + else: + # not in conda, it is a pip package + pip_requirements.append(m) + + # remove any leftover conda packages (they were removed from the pip list) + if conda_supported_req_names: + reqs = [r for r in reqs if r.name.lower() not in conda_supported_req_names] # Conda requirements Hacks: if has_matplotlib: @@ -302,9 +324,17 @@ class CondaAPI(PackageManager): if has_torch and cuda_version == 0: reqs.append(MarkerRequirement(Requirement.parse('cpuonly'))) + # conform conda packages (version/name) + for r in reqs: + # remove .post from version numbers, it fails ~= version, and change == to ~= + if r.specs and r.specs[0]: + r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])] + # conda always likes "-" not "_" + r.req.name = r.req.name.replace('_', '-') + while reqs: # notice, we give conda more freedom in version selection, to help it choose best combination - conda_env['dependencies'] = [r.tostr().replace('==', '~=') for r in reqs] + conda_env['dependencies'] = [r.tostr() for r in reqs] with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name: print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies'])) result = self._run_command( @@ -317,7 +347,7 @@ class CondaAPI(PackageManager): solved = False for bad_r in bad_req: - name = bad_r.split('[')[0].split('=')[0] + name = bad_r.split('[')[0].split('=')[0].split('~')[0].split('<')[0].split('>')[0] # look for name in requirements for r in reqs: if r.name.lower() == name.lower():