Add conda support

This commit is contained in:
allegroai 2020-01-21 16:21:18 +02:00
parent b6e04ab982
commit 599219b02d
2 changed files with 51 additions and 24 deletions

View File

@ -1151,7 +1151,8 @@ class Worker(ServiceCommandSection):
self._update_commit_id(task_id, execution, repo_info) self._update_commit_id(task_id, execution, repo_info)
# Add the script CWD to the python path # Add the script CWD to the python path
python_path = get_python_path(script_dir, execution.entry_point, self.package_api) python_path = get_python_path(script_dir, execution.entry_point, self.package_api) \
if not self.is_conda else None
if python_path: if python_path:
os.environ['PYTHONPATH'] = python_path os.environ['PYTHONPATH'] = python_path
@ -1631,14 +1632,20 @@ class Worker(ServiceCommandSection):
requested_python_version = requested_python_version or \ requested_python_version = requested_python_version or \
Text(self._session.config.get("agent.python_binary", None)) or \ Text(self._session.config.get("agent.python_binary", None)) or \
Text(self._session.config.get("agent.default_python", None)) Text(self._session.config.get("agent.default_python", None))
if self.is_conda:
executable_version_suffix = \
requested_python_version[max(requested_python_version.find('python'), 0):].replace('python', '')
executable_name = 'python'
else:
executable_version, executable_version_suffix, executable_name = self.find_python_executable_for_version( executable_version, executable_version_suffix, executable_name = self.find_python_executable_for_version(
requested_python_version requested_python_version
) )
self._session.config.put("agent.default_python", executable_version)
self._session.config.put("agent.python_binary", executable_name)
venv_dir = Path(venv_dir) if venv_dir else \ venv_dir = Path(venv_dir) if venv_dir else \
Path(self._session.config["agent.venvs_dir"], executable_version_suffix) Path(self._session.config["agent.venvs_dir"], executable_version_suffix)
self._session.config.put("agent.default_python", executable_version)
self._session.config.put("agent.python_binary", executable_name)
first_time = not standalone_mode and ( first_time = not standalone_mode and (
is_windows_platform() is_windows_platform()
or self.is_conda or self.is_conda

View File

@ -227,20 +227,20 @@ class CondaAPI(PackageManager):
self.pip.install_from_file(reqs) self.pip.install_from_file(reqs)
def freeze(self):
    """
    Return the installed requirements of the environment, split by installer.

    Starts from the result of ``self.pip.freeze()`` and, on a best-effort basis,
    adds a 'conda' entry: a list of ``name==version`` strings for every package
    reported by ``conda list --json`` whose channel is not 'pypi' and whose name
    also appears in the pip list (i.e. it is a python package).

    :return: the mapping returned by ``self.pip.freeze()``; a 'conda' key is added only
        if the conda package list could be retrieved and parsed
    """
    requirements = self.pip.freeze()
    try:
        conda_packages = json.loads(
            self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True)
        )
        # package names known to pip, lower-cased; a set keeps the lookup below O(1)
        pip_names = {r.split('==')[0].strip().lower() for r in requirements['pip']}
        # skip pypi packages (already covered by pip) and anything that is not
        # a python package at all (i.e. unknown to pip)
        requirements['conda'] = [
            '{0}{1}{2}'.format(pkg['name'], '==', pkg['version'])
            for pkg in conda_packages
            if pkg['channel'] != 'pypi' and pkg['name'].lower() in pip_names
        ]
    except Exception:
        # Best effort: if conda cannot be queried or its output is malformed,
        # fall back to the pip-only result. Catching Exception (instead of a
        # bare except) lets KeyboardInterrupt / SystemExit propagate.
        pass
    return requirements
def load_requirements(self, requirements): def load_requirements(self, requirements):
# create new environment file # create new environment file
@ -249,6 +249,8 @@ class CondaAPI(PackageManager):
reqs = [] reqs = []
if isinstance(requirements['pip'], six.string_types): if isinstance(requirements['pip'], six.string_types):
requirements['pip'] = requirements['pip'].split('\n') requirements['pip'] = requirements['pip'].split('\n')
if isinstance(requirements.get('conda'), six.string_types):
requirements['conda'] = requirements['conda'].split('\n')
has_torch = False has_torch = False
has_matplotlib = False has_matplotlib = False
try: try:
@ -256,10 +258,15 @@ class CondaAPI(PackageManager):
except: except:
cuda_version = 0 cuda_version = 0
for r in requirements['pip']: # notice 'conda' entry with empty string is a valid conda requirements list, it means pip only
# this should happen if experiment was executed on non-conda machine or old trains client
conda_supported_req = requirements['pip'] if requirements.get('conda', None) is None else requirements['conda']
conda_supported_req_names = []
for r in conda_supported_req:
marker = list(parse(r)) marker = list(parse(r))
if marker: if marker:
m = MarkerRequirement(marker[0]) m = MarkerRequirement(marker[0])
conda_supported_req_names.append(m.name.lower())
if m.req.name.lower() == 'matplotlib': if m.req.name.lower() == 'matplotlib':
has_matplotlib = True has_matplotlib = True
elif m.req.name.lower().startswith('torch'): elif m.req.name.lower().startswith('torch'):
@ -273,8 +280,20 @@ class CondaAPI(PackageManager):
has_torch = True has_torch = True
m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow' m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'
# conda always likes - not _
m.req.name = m.req.name.replace('_', '-')
reqs.append(m) reqs.append(m)
pip_requirements = [] pip_requirements = []
# if we have a conda list, the rest should be installed with pip,
if requirements.get('conda', None) is not None:
for r in requirements['pip']:
marker = list(parse(r))
if marker:
m = MarkerRequirement(marker[0])
if m.name.lower() not in conda_supported_req_names:
pip_requirements.append(m)
# Conda requirements Hacks: # Conda requirements Hacks:
if has_matplotlib: if has_matplotlib:
@ -284,7 +303,8 @@ class CondaAPI(PackageManager):
reqs.append(MarkerRequirement(Requirement.parse('cpuonly'))) reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))
while reqs: while reqs:
conda_env['dependencies'] = [r.tostr().replace('==', '=') for r in reqs] # notice, we give conda more freedom in version selection, to help it choose best combination
conda_env['dependencies'] = [r.tostr().replace('==', '~=') for r in reqs]
with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name: with self.temp_file("conda_env", yaml.dump(conda_env), suffix=".yml") as name:
print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies'])) print('Conda: Trying to install requirements:\n{}'.format(conda_env['dependencies']))
result = self._run_command( result = self._run_command(
@ -338,7 +358,7 @@ class CondaAPI(PackageManager):
if len(empty_lines) >= 2: if len(empty_lines) >= 2:
deps = error_lines[empty_lines[0]+1:empty_lines[1]] deps = error_lines[empty_lines[0]+1:empty_lines[1]]
try: try:
return yaml.load('\n'.join(deps)) return yaml.load('\n'.join(deps), Loader=yaml.SafeLoader)
except: except:
return None return None
return None return None
@ -412,4 +432,4 @@ class PackageNotFoundError(CondaException):
as a singleton YAML list. as a singleton YAML list.
""" """
pkg = attrib(default="", converter=lambda val: yaml.load(val)[0].replace(" ", "")) pkg = attrib(default="", converter=lambda val: yaml.load(val, Loader=yaml.SafeLoader)[0].replace(" ", ""))