Allow to specifying cudatoolkit version in "installed packages" when using Conda as package manager (trains issue #229)

This commit is contained in:
allegroai 2020-10-30 10:06:02 +02:00
parent 293a92f486
commit 216b3e2179

View File

@ -378,7 +378,8 @@ class CondaAPI(PackageManager):
line = 'tensorflow={}'.format(line.split('=')[1])
elif name == 'tensorflow' and cuda_version > 0:
line = 'tensorflow-gpu={}'.format(line.split('=')[1])
elif name in ('cudatoolkit', 'cupti', 'cudnn'):
elif name in ('cupti', 'cudnn'):
# cudatoolkit should pull them based on the cudatoolkit version
continue
elif name.startswith('_'):
continue
@ -543,7 +544,8 @@ class CondaAPI(PackageManager):
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
# remove specific cudatoolkit, it should have being preinstalled.
reqs = [r for r in reqs if r.name not in ('cudatoolkit', 'cudnn', 'cupti')]
# allow to override default cudatoolkit, but not the derivative packages, cudatoolkit should pull them
reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]
if has_torch and cuda_version == 0:
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))