Fix conda pip freeze to be consistent with trains 0.16.3

This commit is contained in:
allegroai 2020-10-11 11:25:35 +03:00
parent 9fe77f3c28
commit a2156e73bf

View File

@ -227,37 +227,61 @@ class CondaAPI(PackageManager):
with self.temp_file("pip_reqs", pip_packages) as reqs:
self.pip.install_from_file(reqs)
def freeze(self):
def freeze(self, freeze_full_environment=False):
requirements = self.pip.freeze()
req_lines = []
conda_lines = []
# noinspection PyBroadException
try:
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
conda_packages_txt = []
requirements_pip = [r.split('==', 1)[0].split('@', 1)[0].strip().lower() for r in requirements['pip']]
for pkg in conda_packages:
# skip if this is a pypi package or it is not a python package at all
# if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip:
# continue
pip_lines = requirements['pip']
conda_packages_json = json.loads(
self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
for r in conda_packages_json:
# check if this is a pypi package, if it is, leave it outside
if not r.get('channel') or r.get('channel') == 'pypi':
name = (r['name'].replace('-', '_'), r['name'])
pip_req_line = [l for l in pip_lines
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
if pip_req_line and \
('@' not in pip_req_line[0] or
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
req_lines.append(pip_req_line[0])
continue
# skip if this is a pypi package or name starts with _ (internal conda)
if pkg['channel'] == 'pypi' or pkg['name'].strip().startswith('_'):
req_lines.append(
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
continue
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
requirements['conda'] = conda_packages_txt
# check if we have it in our required packages
name = r['name']
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
# skip over packages with _
if name.startswith('_'):
continue
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
# make sure we see the conda packages, put them into the pip as well
if conda_lines:
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
requirements['pip'] = req_lines
requirements['conda'] = conda_lines
except Exception:
pass
# noinspection PyBroadException
try:
conda_env_json = json.loads(
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
conda_env_json.pop('name', None)
conda_env_json.pop('prefix', None)
conda_env_json.pop('channels', None)
requirements['conda_env_json'] = json.dumps(conda_env_json)
except Exception:
pass
if freeze_full_environment:
# noinspection PyBroadException
try:
conda_env_json = json.loads(
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
conda_env_json.pop('name', None)
conda_env_json.pop('prefix', None)
conda_env_json.pop('channels', None)
requirements['conda_env_json'] = json.dumps(conda_env_json)
except Exception:
pass
return requirements
@ -457,7 +481,7 @@ class CondaAPI(PackageManager):
# conform conda packages (version/name)
for r in reqs:
# change _ to - in name but not the prefix _ (as this is conda prefix)
if not r.name.startswith('_'):
if not r.name.startswith('_') and not requirements.get('conda', None):
r.name = r.name.replace('_', '-')
# remove .post from version numbers, it fails ~= version, and change == to ~=
if r.specs and r.specs[0]: