From c22eacd3fc5fb07f6802cdd69b911ac15152121c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 12 Oct 2020 11:09:45 +0300 Subject: [PATCH] Add sdk.development.detect_with_conda_freeze (default false) for full conda freeze (requires trains-agent >= 16.2) --- docs/trains.conf | 1 + trains/backend_interface/task/repo/freeze.py | 92 ++++++++++++++------ trains/backend_interface/task/task.py | 8 +- 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/docs/trains.conf b/docs/trains.conf index ba83a3ae..625d7b0c 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -166,6 +166,7 @@ sdk { # If this flag is true (default is false), instead of analyzing the code with Pigar, analyze with `pip freeze` detect_with_pip_freeze: false + detect_with_conda_freeze: false # Log specific environment variables. OS environments are enlisted in the "Environment" section # of the Hyper-Parameters. diff --git a/trains/backend_interface/task/repo/freeze.py b/trains/backend_interface/task/repo/freeze.py index 8961bf06..e5e43d44 100644 --- a/trains/backend_interface/task/repo/freeze.py +++ b/trains/backend_interface/task/repo/freeze.py @@ -1,33 +1,73 @@ import sys +import os +import json from .util import get_command_output -def pip_freeze(): +def pip_freeze(combine_conda_with_pip=False): req_lines = [] local_packages = [] - try: - req_lines = get_command_output([sys.executable, "-m", "pip", "freeze"]).splitlines() - # fix "package @ file://" from pip freeze to "package" - for i, r in enumerate(req_lines): - parts = r.split('@', 1) - if parts and len(parts) == 2 and parts[1].strip().lower().startswith('file://'): - req_lines[i] = parts[0] - local_packages.append((i, parts[0].strip())) - # if we found local packages, at least get their versions (using pip list) - if local_packages: - # noinspection PyBroadException - try: - list_lines = get_command_output( - [sys.executable, "-m", "pip", "list", "--format", "freeze"]).splitlines() - for index, name in local_packages: - line =
[r for r in list_lines if r.strip().startswith(name+'==')] - if not line: - continue - line = line[0] - req_lines[index] = line.strip() - except Exception: - pass - except Exception as ex: - print('Failed calling "pip freeze": {}'.format(str(ex))) - return req_lines + conda_lines = [] + conda_prefix = os.environ.get('CONDA_PREFIX') + if conda_prefix and not conda_prefix.endswith(os.path.sep): + conda_prefix += os.path.sep + + if conda_prefix and sys.executable.startswith(conda_prefix): + pip_lines = get_command_output([sys.executable, "-m", "pip", "freeze"]).splitlines() + conda_packages_json = get_command_output(['conda', 'list', '--json']) + conda_packages_json = json.loads(conda_packages_json) + for r in conda_packages_json: + # check if this is a pypi package, if it is, leave it outside + if not r.get('channel') or r.get('channel') == 'pypi': + name = (r['name'].replace('-', '_'), r['name']) + pip_req_line = [l for l in pip_lines + if l.split('==', 1)[0].strip() in name or l.split('@', 1)[0].strip() in name] + if pip_req_line and \ + ('@' not in pip_req_line[0] or + not pip_req_line[0].split('@', 1)[1].strip().startswith('file://')): + req_lines.append(pip_req_line[0]) + continue + + req_lines.append('{}=={}'.format(name[0], r['version']) if r.get('version') else '{}'.format(name[0])) + continue + + # check if we have it in our required packages + name = r['name'] + # hack support pytorch/torch different naming convention + if name == 'pytorch': + name = 'torch' + # skip over packages with _ + if name.startswith('_'): + continue + conda_lines.append('{}=={}'.format(name, r['version']) if r.get('version') else '{}'.format(name)) + # make sure we see the conda packages, put them into the pip as well + if combine_conda_with_pip and conda_lines: + req_lines += ['', '# Conda Packages', ''] + conda_lines + else: + try: + req_lines = get_command_output([sys.executable, "-m", "pip", "freeze"]).splitlines() + # fix "package @ file://" from pip freeze to "package" + for 
i, r in enumerate(req_lines): + parts = r.split('@', 1) + if parts and len(parts) == 2 and parts[1].strip().lower().startswith('file://'): + req_lines[i] = parts[0] + local_packages.append((i, parts[0].strip())) + # if we found local packages, at least get their versions (using pip list) + if local_packages: + # noinspection PyBroadException + try: + list_lines = get_command_output( + [sys.executable, "-m", "pip", "list", "--format", "freeze"]).splitlines() + for index, name in local_packages: + line = [r for r in list_lines if r.strip().startswith(name+'==')] + if not line: + continue + line = line[0] + req_lines[index] = line.strip() + except Exception: + pass + except Exception as ex: + print('Failed calling "pip freeze": {}'.format(str(ex))) + + return "\n".join(req_lines), "\n".join(conda_lines) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index f385c7b8..1ce94890 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -309,10 +309,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if result.script and script_requirements: entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \ os.path.join(result.script['working_dir'], entry_point) - if config.get('development.detect_with_pip_freeze', False): - conda_requirements = "" + if config.get('development.detect_with_pip_freeze', False) or \ + config.get('development.detect_with_conda_freeze', False): + requirements, conda_requirements = pip_freeze( + config.get('development.detect_with_conda_freeze', False)) requirements = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n\n'\ - + "\n".join(pip_freeze()) + + requirements else: requirements, conda_requirements = script_requirements.get_requirements( entry_point_filename=entry_point_filename)