From c5dd762d9bf3f1614efe01e8437dab8a0b2ae618 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 21 Jan 2020 16:32:57 +0200 Subject: [PATCH] Improve conda support --- .../backend_interface/task/repo/scriptinfo.py | 77 +++++++++++++++---- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 2fcc0bbc..67fefce8 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -5,10 +5,12 @@ from tempfile import mkstemp import attr import collections import logging +import json from furl import furl from pathlib2 import Path from threading import Thread, Event +from .util import get_command_output from ....backend_api import Session from ....debugging import get_logger from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult @@ -63,7 +65,7 @@ class ScriptRequirements(object): reqs, try_imports, guess = gr.extract_reqs() return self.create_requirements_txt(reqs) except Exception: - return '' + return '', '' @staticmethod def _patched_project_import_modules(project_path, ignores): @@ -115,6 +117,10 @@ class ScriptRequirements(object): modules |= fmodules try_imports |= try_ipts + return ScriptRequirements.add_trains_used_packages(modules), try_imports, local_mods + + @staticmethod + def add_trains_used_packages(modules): # hack: forcefully insert storage modules if we have them # noinspection PyBroadException try: @@ -146,11 +152,34 @@ class ScriptRequirements(object): except Exception: pass - return modules, try_imports, local_mods + return modules @staticmethod def create_requirements_txt(reqs): # write requirements.txt + try: + conda_requirements = '' + conda_prefix = os.environ.get('CONDA_PREFIX') + if conda_prefix and not conda_prefix.endswith(os.path.sep): + conda_prefix += os.path.sep + if conda_prefix and sys.executable.startswith(conda_prefix): + conda_packages_json = get_command_output(['conda', 'list', '--json']) + conda_packages_json = json.loads(conda_packages_json) + reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()} + for r in conda_packages_json: + # check if this is a pypi package, if it is, leave it outside + if not r.get('channel') or r.get('channel') == 'pypi': + continue + # check if we have it in our required packages + name = r['name'].lower().replace('-', '_') + # hack support pytorch/torch different naming convention + if name == 'pytorch': + name = 'torch' + k, v = reqs_lower.get(name, (None, None)) + if k: + conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version) + except: + conda_requirements = '' # python version header requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n' @@ -178,7 +207,7 @@ class ScriptRequirements(object): requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k) requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()]) - return requirements_txt + return requirements_txt, conda_requirements class _JupyterObserver(object): @@ -206,6 +235,15 @@ class _JupyterObserver(object): def signal_sync(cls, *_): cls._sync_event.set() + @classmethod + def close(cls): + if not cls._thread: + return + cls._exit_event.set() + cls._sync_event.set() + cls._thread.join() + cls._thread = None + @classmethod def _daemon(cls, jupyter_notebook_filename): from trains import Task @@ -245,13 +283,11 @@ class _JupyterObserver(object): last_update_ts = None counter = 0 prev_script_hash = None - # main observer loop - while True: + # main observer loop, check if we need to exit + while not cls._exit_event.wait(timeout=0.): # wait for timeout or sync event cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency) - # check if we need to exit - if cls._exit_event.wait(timeout=0.): - return + cls._sync_event.clear() counter += 1 # noinspection PyBroadException @@ -283,23 +319,25 @@ class _JupyterObserver(object): if prev_script_hash and prev_script_hash == current_script_hash: continue requirements_txt = '' + conda_requirements = '' # parse jupyter python script and prepare pip requirements (pigar) # if backend supports requirements if file_import_modules and Session.check_min_api_version('2.2'): fmodules, _ = file_import_modules(notebook.parts[-1], script_code) + fmodules = ScriptRequirements.add_trains_used_packages(fmodules) installed_pkgs = get_installed_pkgs_detail() reqs = ReqsModules() for name in fmodules: if name in installed_pkgs: pkg_name, version = installed_pkgs[name] reqs.add(pkg_name, version, fmodules[name]) - requirements_txt = ScriptRequirements.create_requirements_txt(reqs) + requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs) # update script prev_script_hash = current_script_hash data_script = task.data.script data_script.diff = script_code - data_script.requirements = {'pip': requirements_txt} + data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements} task._update_script(script=data_script) # update requirements task._update_requirements(requirements=requirements_txt) @@ -491,8 +529,13 @@ class ScriptInfo(object): _log("no info for {}", script_dir) repo_root = repo_info.root or script_dir - working_dir = cls._get_working_dir(repo_root) - entry_point = cls._get_entry_point(repo_root, script_path) + if not plugin: + working_dir = '.' + entry_point = str(script_path.name) + else: + working_dir = cls._get_working_dir(repo_root) + entry_point = cls._get_entry_point(repo_root, script_path) + if check_uncommitted: diff = cls._get_script_code(script_path.as_posix()) \ if not plugin or not repo_info.commit else repo_info.diff @@ -500,16 +543,18 @@ class ScriptInfo(object): diff = '' # if this is not jupyter, get the requirements.txt requirements = '' + conda_requirements = '' # create requirements if backend supports requirements # if jupyter is present, requirements will be created in the background, when saving a snapshot if not jupyter_filepath and Session.check_min_api_version('2.2'): script_requirements = ScriptRequirements( Path(repo_root).as_posix() if repo_info.url else script_path.as_posix()) if create_requirements: - requirements = script_requirements.get_requirements() + requirements, conda_requirements = script_requirements.get_requirements() else: script_requirements = None + script_info = dict( repository=furl(repo_info.url).remove(username=True, password=True).tostr(), branch=repo_info.branch, @@ -517,7 +562,7 @@ class ScriptInfo(object): entry_point=entry_point, working_dir=working_dir, diff=diff, - requirements={'pip': requirements} if requirements else None, + requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None, ) messages = [] @@ -545,6 +590,10 @@ class ScriptInfo(object): log.warning("Failed auto-detecting task repository: {}".format(ex)) return ScriptInfoResult(), None + @classmethod + def close(cls): + _JupyterObserver.close() + @attr.s class ScriptInfoResult(object):