From 216b3e21790659467007957d26172698fd74e075 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 30 Oct 2020 10:06:02 +0200 Subject: [PATCH] Allow to specifying cudatoolkit version in "installed packages" when using Conda as package manager (trains issue #229) --- trains_agent/helper/package/conda_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trains_agent/helper/package/conda_api.py b/trains_agent/helper/package/conda_api.py index c8f4d7c..3b4a684 100644 --- a/trains_agent/helper/package/conda_api.py +++ b/trains_agent/helper/package/conda_api.py @@ -378,7 +378,8 @@ class CondaAPI(PackageManager): line = 'tensorflow={}'.format(line.split('=')[1]) elif name == 'tensorflow' and cuda_version > 0: line = 'tensorflow-gpu={}'.format(line.split('=')[1]) - elif name in ('cudatoolkit', 'cupti', 'cudnn'): + elif name in ('cupti', 'cudnn'): + # cudatoolkit should pull them based on the cudatoolkit version continue elif name.startswith('_'): continue @@ -543,7 +544,8 @@ class CondaAPI(PackageManager): reqs.append(MarkerRequirement(Requirement.parse('kiwisolver'))) # remove specific cudatoolkit, it should have being preinstalled. - reqs = [r for r in reqs if r.name not in ('cudatoolkit', 'cudnn', 'cupti')] + # allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them + reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')] if has_torch and cuda_version == 0: reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))