mirror of
https://github.com/clearml/clearml-agent
synced 2025-05-11 15:21:05 +00:00
Fix conda support for full environments stored by trains 0.16.3. Add the agent.package_manager.conda_full_env_update option, which allows the agent to write the installed conda environment back to the task's requirements (default is false, to preserve the previous behavior)
This commit is contained in:
parent
6f078afafd
commit
9fe77f3c28
@ -62,6 +62,7 @@ agent {
|
|||||||
|
|
||||||
# additional conda channels to use when installing with conda package manager
|
# additional conda channels to use when installing with conda package manager
|
||||||
conda_channels: ["pytorch", "conda-forge", ]
|
conda_channels: ["pytorch", "conda-forge", ]
|
||||||
|
# conda_full_env_update: false
|
||||||
|
|
||||||
# set the priority packages to be installed before the rest of the required packages
|
# set the priority packages to be installed before the rest of the required packages
|
||||||
# priority_packages: ["cython", "numpy", "setuptools", ]
|
# priority_packages: ["cython", "numpy", "setuptools", ]
|
||||||
|
@ -1452,7 +1452,10 @@ class Worker(ServiceCommandSection):
|
|||||||
|
|
||||||
# do not update the task packages if we are using conda,
|
# do not update the task packages if we are using conda,
|
||||||
# it will most likely make the task environment unreproducible
|
# it will most likely make the task environment unreproducible
|
||||||
freeze = self.freeze_task_environment(current_task.id if not self.is_conda else None,
|
skip_freeze_update = self.is_conda and not self._session.config.get(
|
||||||
|
"agent.package_manager.conda_full_env_update", False)
|
||||||
|
|
||||||
|
freeze = self.freeze_task_environment(current_task.id if not skip_freeze_update else None,
|
||||||
requirements_manager=requirements_manager)
|
requirements_manager=requirements_manager)
|
||||||
script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
|
script_dir = (directory if isinstance(directory, Path) else Path(directory)).absolute().as_posix()
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from collections import OrderedDict
|
||||||
from distutils.spawn import find_executable
|
from distutils.spawn import find_executable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
@ -228,21 +229,126 @@ class CondaAPI(PackageManager):
|
|||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
requirements = self.pip.freeze()
|
requirements = self.pip.freeze()
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
|
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
|
||||||
conda_packages_txt = []
|
conda_packages_txt = []
|
||||||
requirements_pip = [r.split('==')[0].strip().lower() for r in requirements['pip']]
|
requirements_pip = [r.split('==', 1)[0].split('@', 1)[0].strip().lower() for r in requirements['pip']]
|
||||||
for pkg in conda_packages:
|
for pkg in conda_packages:
|
||||||
# skip if this is a pypi package or it is not a python package at all
|
# skip if this is a pypi package or it is not a python package at all
|
||||||
if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
|
# if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# skip if this is a pypi package or name starts with _ (internal conda)
|
||||||
|
if pkg['channel'] == 'pypi' or pkg['name'].strip().startswith('_'):
|
||||||
continue
|
continue
|
||||||
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
|
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
|
||||||
requirements['conda'] = conda_packages_txt
|
requirements['conda'] = conda_packages_txt
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
conda_env_json = json.loads(
|
||||||
|
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
|
||||||
|
conda_env_json.pop('name', None)
|
||||||
|
conda_env_json.pop('prefix', None)
|
||||||
|
conda_env_json.pop('channels', None)
|
||||||
|
requirements['conda_env_json'] = json.dumps(conda_env_json)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
return requirements
|
return requirements
|
||||||
|
|
||||||
|
def _load_conda_full_env(self, conda_env_dict, requirements):
    """
    Install a complete conda environment description, then its pip section.

    The conda part is dumped to a temporary YAML file and applied with
    ``conda env update``; the packages listed under the nested ``pip`` entry
    are installed afterwards through ``self.pip``.

    :param conda_env_dict: environment dictionary (e.g. the decoded
        ``conda_env_json`` requirements entry). Its ``dependencies`` list holds
        ``name=version...`` strings and, optionally, a ``{'pip': [...]}`` dict.
        The dictionary passed in is not modified.
    :param requirements: the task requirements dictionary. Only its ``pip``
        entry (a string or a list of lines) is read, to find VCS requirements
        that replace the same-named entries of the conda ``pip`` section.
    :return: True if the environment was installed, False if conda reported
        packages it could not install.
    :raises Exception: any error raised while installing the pip section is
        printed and re-raised.
    """
    # noinspection PyBroadException
    try:
        cuda_version = int(self.session.config.get('agent.cuda_version', 0))
    except Exception:
        cuda_version = 0

    def _pinned(package, dep_line, max_parts=None):
        # Rebuild "package=version" from a "name=version[=...]" line.
        # Lines carrying no version are returned untouched instead of raising IndexError.
        fields = dep_line.split('=')
        version = fields[1] if len(fields) > 1 else ''
        if not version:
            return dep_line
        if max_parts:
            version = '.'.join(version.split('.')[:max_parts])
        return '{}={}'.format(package, version)

    # work on a shallow copy so the caller's dictionary is left as it was
    conda_env = dict(conda_env_dict)
    conda_env['channels'] = self.extra_channels

    # keyed by package name so a renamed entry (tensorflow <-> tensorflow-gpu)
    # overrides an existing one instead of being listed twice
    new_dependencies = OrderedDict()
    pip_requirements = None
    for line in conda_env.get('dependencies') or []:
        if isinstance(line, dict):
            # nested {'pip': [...]} section, installed separately in step 2
            pip_requirements = line.get('pip') or pip_requirements
            continue
        pkg_name = line.strip().split('=', 1)[0].lower()
        if pkg_name == 'pip':
            continue
        elif pkg_name == 'python':
            # python version, only major.minor
            line = _pinned('python', line, max_parts=2)
        elif pkg_name == 'tensorflow-gpu' and cuda_version == 0:
            # no cuda configured: use the non-gpu package
            line = _pinned('tensorflow', line)
        elif pkg_name == 'tensorflow' and cuda_version > 0:
            line = _pinned('tensorflow-gpu', line)
        elif pkg_name in ('cudatoolkit', 'cupti', 'cudnn'):
            # cuda runtime packages are not restored from the stored environment
            continue
        elif pkg_name.startswith('_'):
            # names starting with "_" are internal conda packages
            continue
        new_dependencies[line.split('=', 1)[0].strip()] = line

    # fix packages:
    conda_env['dependencies'] = list(new_dependencies.values())

    with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as env_file:
        print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
        result = self._run_command(
            ("env", "update", "-p", self.path, "--file", env_file)
        )

    # check if we need to remove specific packages
    bad_req = self._parse_conda_result_bad_packges(result)
    if bad_req:
        print('failed installing the following conda packages: {}'.format(bad_req))
        return False

    if pip_requirements:
        # create a list of vcs packages that we need to replace in the pip section
        vcs_reqs = {}
        pip_section = requirements.get('pip') or []
        if isinstance(pip_section, six.string_types):
            pip_section = pip_section.splitlines()
        for pip_line in pip_section:
            # noinspection PyBroadException
            try:
                marker = list(parse(pip_line))
            except Exception:
                marker = None
            if not marker:
                continue

            m = MarkerRequirement(marker[0])
            if m.vcs:
                vcs_reqs[m.name] = m

        try:
            pip_req_str = [str(vcs_reqs.get(r.split('=', 1)[0], r)) for r in pip_requirements
                           if not r.startswith('pip=') and not r.startswith('virtualenv=')]
            print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
            # pip is the selected package manager only for the duration of this step
            PackageManager._selected_manager = self.pip
            self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
        except Exception as e:
            print(e)
            # bare raise keeps the original traceback
            raise
        finally:
            PackageManager._selected_manager = self

    self.requirements_manager.post_install(self.session)
    # report success explicitly: load_requirements() returns this value as-is,
    # and its regular (non full-env) path returns True on success
    return True
|
||||||
|
|
||||||
def load_requirements(self, requirements):
|
def load_requirements(self, requirements):
|
||||||
|
# if we have a full conda environment, use it and pass the pip to pip
|
||||||
|
if requirements.get('conda_env_json'):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
conda_env_json = json.loads(requirements.get('conda_env_json'))
|
||||||
|
print('Conda restoring full yaml environment')
|
||||||
|
return self._load_conda_full_env(conda_env_json, requirements)
|
||||||
|
except Exception:
|
||||||
|
print('Could not load fully stored conda environment, falling back to requirements')
|
||||||
|
|
||||||
# create new environment file
|
# create new environment file
|
||||||
conda_env = dict()
|
conda_env = dict()
|
||||||
conda_env['channels'] = self.extra_channels
|
conda_env['channels'] = self.extra_channels
|
||||||
@ -276,6 +382,15 @@ class CondaAPI(PackageManager):
|
|||||||
if m.vcs:
|
if m.vcs:
|
||||||
pip_requirements.append(m)
|
pip_requirements.append(m)
|
||||||
continue
|
continue
|
||||||
|
# Skip over pip
|
||||||
|
if m.name in ('pip', 'virtualenv', ):
|
||||||
|
continue
|
||||||
|
# python version, only major.minor
|
||||||
|
if m.name == 'python' and m.specs:
|
||||||
|
m.specs = [(m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])), ]
|
||||||
|
if '.' not in m.specs[0][1]:
|
||||||
|
continue
|
||||||
|
|
||||||
conda_supported_req_names.append(m.name.lower())
|
conda_supported_req_names.append(m.name.lower())
|
||||||
if m.req.name.lower() == 'matplotlib':
|
if m.req.name.lower() == 'matplotlib':
|
||||||
has_matplotlib = True
|
has_matplotlib = True
|
||||||
@ -303,15 +418,20 @@ class CondaAPI(PackageManager):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
m = MarkerRequirement(marker[0])
|
m = MarkerRequirement(marker[0])
|
||||||
|
# skip over local files (we cannot change the version to a local file)
|
||||||
|
if m.local_file:
|
||||||
|
continue
|
||||||
m_name = m.name.lower()
|
m_name = m.name.lower()
|
||||||
if m_name in conda_supported_req_names:
|
if m_name in conda_supported_req_names:
|
||||||
# this package is in the conda list,
|
# this package is in the conda list,
|
||||||
# make sure that if we changed version and we match it in conda
|
# make sure that if we changed version and we match it in conda
|
||||||
conda_supported_req_names.remove(m_name)
|
## conda_supported_req_names.remove(m_name)
|
||||||
for cr in reqs:
|
for cr in reqs:
|
||||||
if m_name == cr.name.lower():
|
if m_name.lower().replace('_', '-') == cr.name.lower().replace('_', '-'):
|
||||||
# match versions
|
# match versions
|
||||||
cr.specs = m.specs
|
cr.specs = m.specs
|
||||||
|
# # conda always likes "-" not "_" but only on pypi packages
|
||||||
|
# cr.name = cr.name.lower().replace('_', '-')
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# not in conda, it is a pip package
|
# not in conda, it is a pip package
|
||||||
@ -319,29 +439,38 @@ class CondaAPI(PackageManager):
|
|||||||
if m_name == 'matplotlib':
|
if m_name == 'matplotlib':
|
||||||
has_matplotlib = True
|
has_matplotlib = True
|
||||||
|
|
||||||
# remove any leftover conda packages (they were removed from the pip list)
|
|
||||||
if conda_supported_req_names:
|
|
||||||
reqs = [r for r in reqs if r.name.lower() not in conda_supported_req_names]
|
|
||||||
|
|
||||||
# Conda requirements Hacks:
|
# Conda requirements Hacks:
|
||||||
if has_matplotlib:
|
if has_matplotlib:
|
||||||
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
|
reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
|
||||||
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
|
reqs.append(MarkerRequirement(Requirement.parse('python-graphviz')))
|
||||||
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
|
reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))
|
||||||
|
|
||||||
|
# remove specific cudatoolkit, it should have been preinstalled.
|
||||||
|
reqs = [r for r in reqs if r.name not in ('cudatoolkit', 'cudnn', 'cupti')]
|
||||||
|
|
||||||
if has_torch and cuda_version == 0:
|
if has_torch and cuda_version == 0:
|
||||||
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
|
reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
|
||||||
|
|
||||||
|
# make sure we have no double entries
|
||||||
|
reqs = list(OrderedDict((r.name, r) for r in reqs).values())
|
||||||
|
|
||||||
# conform conda packages (version/name)
|
# conform conda packages (version/name)
|
||||||
for r in reqs:
|
for r in reqs:
|
||||||
|
# change _ to - in name but not the prefix _ (as this is conda prefix)
|
||||||
|
if not r.name.startswith('_'):
|
||||||
|
r.name = r.name.replace('_', '-')
|
||||||
# remove .post from version numbers, it fails ~= version, and change == to ~=
|
# remove .post from version numbers, it fails ~= version, and change == to ~=
|
||||||
if r.specs and r.specs[0]:
|
if r.specs and r.specs[0]:
|
||||||
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
|
r.specs = [(r.specs[0][0].replace('==', '~='), r.specs[0][1].split('.post')[0])]
|
||||||
# conda always likes "-" not "_"
|
|
||||||
r.req.name = r.req.name.replace('_', '-')
|
|
||||||
|
|
||||||
while reqs:
|
while reqs:
|
||||||
# notice, we give conda more freedom in version selection, to help it choose best combination
|
# notice, we give conda more freedom in version selection, to help it choose best combination
|
||||||
conda_env['dependencies'] = [r.tostr() for r in reqs]
|
def clean_ver(ar):
|
||||||
|
if not ar.specs:
|
||||||
|
return ar.tostr()
|
||||||
|
ar.specs = [(ar.specs[0][0], ar.specs[0][1] + '.0' if '.' not in ar.specs[0][1] else ar.specs[0][1])]
|
||||||
|
return ar.tostr()
|
||||||
|
conda_env['dependencies'] = [clean_ver(r) for r in reqs]
|
||||||
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
|
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
|
||||||
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
|
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
|
||||||
result = self._run_command(
|
result = self._run_command(
|
||||||
@ -371,12 +500,15 @@ class CondaAPI(PackageManager):
|
|||||||
|
|
||||||
if pip_requirements:
|
if pip_requirements:
|
||||||
try:
|
try:
|
||||||
pip_req_str = [r.tostr() for r in pip_requirements]
|
pip_req_str = [r.tostr() for r in pip_requirements if r.name not in ('pip', 'virtualenv', )]
|
||||||
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
print('Conda: Installing requirements: step 2 - using pip:\n{}'.format(pip_req_str))
|
||||||
self.pip.load_requirements('\n'.join(pip_req_str))
|
PackageManager._selected_manager = self.pip
|
||||||
|
self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise e
|
raise e
|
||||||
|
finally:
|
||||||
|
PackageManager._selected_manager = self
|
||||||
|
|
||||||
self.requirements_manager.post_install(self.session)
|
self.requirements_manager.post_install(self.session)
|
||||||
return True
|
return True
|
||||||
@ -442,7 +574,7 @@ class CondaAPI(PackageManager):
|
|||||||
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
|
return CommandSequence(self.source, self.pip.get_python_command(extra=extra))
|
||||||
|
|
||||||
|
|
||||||
# enable hashing with cmp=False because pdb fails on unhashable exceptions
|
# enable hashing with cmp=False because pdb fails on un-hashable exceptions
|
||||||
exception = attrs(str=True, cmp=False)
|
exception = attrs(str=True, cmp=False)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user