Fix conda pip freeze to be consistent with trains 0.16.3

This commit is contained in:
allegroai 2020-10-11 11:25:35 +03:00
parent 9fe77f3c28
commit a2156e73bf

View File

@ -227,27 +227,51 @@ class CondaAPI(PackageManager):
with self.temp_file("pip_reqs", pip_packages) as reqs: with self.temp_file("pip_reqs", pip_packages) as reqs:
self.pip.install_from_file(reqs) self.pip.install_from_file(reqs)
def freeze(self): def freeze(self, freeze_full_environment=False):
requirements = self.pip.freeze() requirements = self.pip.freeze()
req_lines = []
conda_lines = []
# noinspection PyBroadException # noinspection PyBroadException
try: try:
conda_packages = json.loads(self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True)) pip_lines = requirements['pip']
conda_packages_txt = [] conda_packages_json = json.loads(
requirements_pip = [r.split('==', 1)[0].split('@', 1)[0].strip().lower() for r in requirements['pip']] self._run_command((self.conda, "list", "--json", "-p", self.path), raw=True))
for pkg in conda_packages: for r in conda_packages_json:
# skip if this is a pypi package or it is not a python package at all # check if this is a pypi package, if it is, leave it outside
# if pkg['channel'] == 'pypi' or pkg['name'].lower() not in requirements_pip: if not r.get('channel') or r.get('channel') == 'pypi':
# continue name = (r['name'].replace('-', '_'), r['name'])
pip_req_line = [l for l in pip_lines
# skip if this is a pypi package or name starts with _ (internal conda) if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name]
if pkg['channel'] == 'pypi' or pkg['name'].strip().startswith('_'): if pip_req_line and \
('@' not in pip_req_line[0] or
not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')):
req_lines.append(pip_req_line[0])
continue continue
conda_packages_txt.append('{0}{1}{2}'.format(pkg['name'], '==', pkg['version']))
requirements['conda'] = conda_packages_txt req_lines.append(
'{}=={}'.format(name[1], r['version']) if r.get('version') else '{}'.format(name[1]))
continue
# check if we have it in our required packages
name = r['name']
# hack support pytorch/torch different naming convention
if name == 'pytorch':
name = 'torch'
# skip over packages with _
if name.startswith('_'):
continue
conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name))
# make sure we see the conda packages, put them into the pip as well
if conda_lines:
req_lines = ['# Conda Packages', ''] + conda_lines + ['', '# pip Packages', ''] + req_lines
requirements['pip'] = req_lines
requirements['conda'] = conda_lines
except Exception: except Exception:
pass pass
if freeze_full_environment:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
conda_env_json = json.loads( conda_env_json = json.loads(
@ -457,7 +481,7 @@ class CondaAPI(PackageManager):
# conform conda packages (version/name) # conform conda packages (version/name)
for r in reqs: for r in reqs:
# change _ to - in name but not the prefix _ (as this is conda prefix) # change _ to - in name but not the prefix _ (as this is conda prefix)
if not r.name.startswith('_'): if not r.name.startswith('_') and not requirements.get('conda', None):
r.name = r.name.replace('_', '-') r.name = r.name.replace('_', '-')
# remove .post from version numbers, it fails ~= version, and change == to ~= # remove .post from version numbers, it fails ~= version, and change == to ~=
if r.specs and r.specs[0]: if r.specs and r.specs[0]: