Fix conda pip freeze to be consistent with trains 0.16.3

This commit is contained in:
allegroai 2020-10-11 11:25:35 +03:00
parent 9fe77f3c28
commit a2156e73bf

View File

@ -227,37 +227,61 @@ class CondaAPI(PackageManager):
with self.temp_file("pip_reqs", pip_packages) as reqs: with self.temp_file("pip_reqs", pip_packages) as reqs:
self.pip.install_from_file(reqs) self.pip.install_from_file(reqs)
def freeze(self): def freeze(self, freeze_full_environment=False):
requirements = self.pip.freeze() requirements = self.pip.freeze()
req_lines = []
conda_lines = []
# noinspection PyBroadException # noinspection PyBroadException
try: try:
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True)) pip_lines = requirements['pip']
conda_packages_txt = [] conda_packages_json = json.loads(
requirements_pip = [r.split('==', 1)[0].split('@', 1)[0].strip().lower() for r in requirements['pip']] self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
for pkg in conda_packages: for r in conda_packages_json:
# skip if this is a pypi package or it is not a python package at all # check if this is a pypi package, if it is, leave it outside
# if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip: if not r.get('channel') or r.get('channel') == 'pypi':
# continue name = (r['name'].replace('-', '_'), r['name'])
pip_req_line = [l for l in pip_lines
if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
if pip_req_line and \
('@' not in pip_req_line[0] or
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
req_lines.append(pip_req_line[0])
continue
# skip if this is a pypi package or name starts with _ (internal conda) req_lines.append(
if pkg['channel'] == 'pypi' or pkg['name'].strip().startswith('_'): '{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
continue continue
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
requirements['conda'] = conda_packages_txt # check if we have it in our required packages
name = r['name']
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
# skip over packages with _
if name.startswith('_'):
continue
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
# make sure we see the conda packages, put them into the pip as well
if conda_lines:
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
requirements['pip'] = req_lines
requirements['conda'] = conda_lines
except Exception: except Exception:
pass pass
# noinspection PyBroadException if freeze_full_environment:
try: # noinspection PyBroadException
conda_env_json = json.loads( try:
self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True)) conda_env_json = json.loads(
conda_env_json.pop('name', None) self._run_command((self.conda, "env", "export", "--json", "-p", self.path), raw=True))
conda_env_json.pop('prefix', None) conda_env_json.pop('name', None)
conda_env_json.pop('channels', None) conda_env_json.pop('prefix', None)
requirements['conda_env_json'] = json.dumps(conda_env_json) conda_env_json.pop('channels', None)
except Exception: requirements['conda_env_json'] = json.dumps(conda_env_json)
pass except Exception:
pass
return requirements return requirements
@ -457,7 +481,7 @@ class CondaAPI(PackageManager):
# conform conda packages (version/name) # conform conda packages (version/name)
for r in reqs: for r in reqs:
# change _ to - in name but not the prefix _ (as this is conda prefix) # change _ to - in name but not the prefix _ (as this is conda prefix)
if not r.name.startswith('_'): if not r.name.startswith('_') and not requirements.get('conda', None):
r.name = r.name.replace('_', '-') r.name = r.name.replace('_', '-')
# remove .post from version numbers, it fails ~= version, and change == to ~= # remove .post from version numbers, it fails ~= version, and change == to ~=
if r.specs and r.specs[0]: if r.specs and r.specs[0]: