From 577010c421b11ae9007ef70f2d39bd2bf583b9de Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 6 Jul 2019 23:01:15 +0300 Subject: [PATCH] Add auto requirement.txt generation --- requirements.txt | 3 +- .../backend_interface/task/repo/scriptinfo.py | 212 +++++++++++++++++- 2 files changed, 211 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 99658e63..c0229a6b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ jsonschema>=2.6.0 numpy>=1.10 opencv-python>=3.2.0.8 pathlib2>=2.3.0 +pigar>=0.9.2 plotly>=3.9.0 psutil>=3.4.2 pyhocon>=0.3.38 @@ -24,7 +25,7 @@ python-dateutil>=2.6.1 pyjwt>=1.6.4 PyYAML>=3.12 requests-file>=1.4.2 -requests>=2.18.4 +requests>=2.20.0 six>=1.11.0 tqdm>=4.19.5 typing>=3.6.4 diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 075e129e..7c3165ff 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -2,9 +2,13 @@ import os import sys import attr +import collections +import logging from furl import furl from pathlib2 import Path +from threading import Thread, Event +from ....backend_api import Session from ....debugging import get_logger from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult @@ -15,17 +19,197 @@ class ScriptInfoError(Exception): pass +class ScriptRequirements(object): + def __init__(self, root_folder): + self._root_folder = root_folder + + def get_requirements(self): + try: + from pigar import reqs + reqs.project_import_modules = ScriptRequirements._patched_project_import_modules + from pigar.__main__ import GenerateReqs + from pigar.log import logger + logger.setLevel(logging.WARNING) + installed_pkgs = reqs.get_installed_pkgs_detail() + gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs, + ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints']) + reqs, try_imports, guess = gr.extract_reqs() + return self.create_requirements_txt(reqs) + except Exception: + return '' + + @staticmethod + def _patched_project_import_modules(project_path, ignores): + """ + copied form pigar req.project_import_modules + patching, os.getcwd() is incorrectly used + """ + from pigar.modules import ImportedModules + from pigar.reqs import file_import_modules + modules = ImportedModules() + try_imports = set() + local_mods = list() + cur_dir = project_path # os.getcwd() + ignore_paths = collections.defaultdict(set) + if not ignores: + ignore_paths[project_path].add('.git') + else: + for path in ignores: + parent_dir = os.path.dirname(path) + ignore_paths[parent_dir].add(os.path.basename(path)) + + for dirpath, dirnames, files in os.walk(project_path, followlinks=True): + if dirpath in ignore_paths: + dirnames[:] = [d for d in dirnames + if d not in ignore_paths[dirpath]] + py_files = list() + for fn in files: + # C extension. + if fn.endswith('.so'): + local_mods.append(fn[:-3]) + # Normal Python file. + if fn.endswith('.py'): + local_mods.append(fn[:-3]) + py_files.append(fn) + if '__init__.py' in files: + local_mods.append(os.path.basename(dirpath)) + for file in py_files: + fpath = os.path.join(dirpath, file) + fake_path = fpath.split(cur_dir)[1][1:] + with open(fpath, 'rb') as f: + fmodules, try_ipts = file_import_modules(fake_path, f.read()) + modules |= fmodules + try_imports |= try_ipts + + return modules, try_imports, local_mods + + @staticmethod + def create_requirements_txt(reqs): + # write requirements.txt + requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n' + for k, v in reqs.sorted_items(): + requirements_txt += '\n' + requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()]) + if k == '-e': + requirements_txt += '{0} {1}\n'.format(k, v.version) + elif v: + requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version) + else: + requirements_txt += '{0}\n'.format(k) + return requirements_txt + + +class _JupyterObserver(object): + _thread = None + _exit_event = Event() + _sample_frequency = 60. + _first_sample_frequency = 3. + + @classmethod + def observer(cls, jupyter_notebook_filename): + if cls._thread is not None: + cls._exit_event.set() + cls._thread.join() + + cls._exit_event.clear() + cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) + cls._thread.daemon = True + cls._thread.start() + + @classmethod + def _daemon(cls, jupyter_notebook_filename): + from trains import Task + + # load jupyter notebook package + # noinspection PyBroadException + try: + from nbconvert.exporters.script import ScriptExporter + _script_exporter = ScriptExporter() + except Exception: + return + # load pigar + # noinspection PyBroadException + try: + from pigar.reqs import get_installed_pkgs_detail, file_import_modules + from pigar.modules import ReqsModules + from pigar.log import logger + logger.setLevel(logging.WARNING) + except Exception: + file_import_modules = None + # main observer loop + notebook = Path(jupyter_notebook_filename) + last_update_ts = None + counter = 0 + prev_script_hash = None + while True: + if cls._exit_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency): + return + counter += 1 + # noinspection PyBroadException + try: + if not notebook.exists(): + continue + # check if notebook changed + if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: + continue + last_update_ts = notebook.stat().st_mtime + task = Task.current_task() + if not task: + continue + # get notebook python script + script_code, resources = _script_exporter.from_filename(jupyter_notebook_filename) + current_script_hash = hash(script_code) + if prev_script_hash and prev_script_hash == current_script_hash: + continue + requirements_txt = '' + # parse jupyter python script and prepare pip requirements (pigar) + # if backend supports requirements + if file_import_modules and Session.api_version > '2.1': + fmodules, _ = file_import_modules(notebook.parts[-1], script_code) + installed_pkgs = get_installed_pkgs_detail() + reqs = ReqsModules() + for name in fmodules: + if name in installed_pkgs: + pkg_name, version = installed_pkgs[name] + reqs.add(pkg_name, version, fmodules[name]) + requirements_txt = ScriptRequirements.create_requirements_txt(reqs) + + # update script + prev_script_hash = current_script_hash + data_script = task.data.script + data_script.diff = script_code + data_script.requirements = {'pip': requirements_txt} + task._update_script(script=data_script) + # update requirements + if requirements_txt: + task._update_requirements(requirements=requirements_txt) + except Exception: + pass + + class ScriptInfo(object): plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()] """ Script info detection plugins, in order of priority """ + @classmethod + def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename): + # noinspection PyBroadException + try: + if 'IPython' in sys.modules: + from IPython import get_ipython + if get_ipython(): + _JupyterObserver.observer(jupyter_notebook_filename) + except Exception: + pass + @classmethod def _get_jupyter_notebook_filename(cls): if not sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'): return None # we can safely assume that we can import the notebook package here + # noinspection PyBroadException try: from notebook.notebookapp import list_running_servers import requests @@ -51,6 +235,9 @@ class ScriptInfo(object): if not entry_point.is_file(): entry_point = (Path.cwd() / notebook_path).absolute() + # install the post store hook, so always have a synced file in the system + cls._jupyter_install_post_store_hook(entry_point.as_posix()) + # now replace the .ipynb with .py # we assume we will have that file available with the Jupyter notebook plugin entry_point = entry_point.with_suffix('.py') @@ -83,7 +270,18 @@ class ScriptInfo(object): return os.path.curdir @classmethod - def _get_script_info(cls, filepath, check_uncommitted=False, log=None): + def _get_script_code(cls, script_path): + # noinspection PyBroadException + try: + with open(script_path, 'r') as f: + script_code = f.read() + return script_code + except Exception: + pass + return '' + + @classmethod + def _get_script_info(cls, filepath, check_uncommitted=True, log=None): jupyter_filepath = cls._get_jupyter_notebook_filename() if jupyter_filepath: script_path = Path(os.path.normpath(jupyter_filepath)).absolute() @@ -121,6 +319,13 @@ class ScriptInfo(object): repo_root = repo_info.root or script_dir working_dir = cls._get_working_dir(repo_root) entry_point = cls._get_entry_point(repo_root, script_path) + diff = cls._get_script_code(script_path.as_posix()) if not plugin or not repo_info.commit else repo_info.diff + # if this is not jupyter, get the requirements.txt + requirements = '' + # create requirements if backend supports requirements + if not jupyter_filepath and Session.api_version > '2.1': + script_requirements = ScriptRequirements(Path(repo_root).as_posix()) + requirements = script_requirements.get_requirements() script_info = dict( repository=furl(repo_info.url).remove(username=True, password=True).tostr(), @@ -128,7 +333,8 @@ class ScriptInfo(object): version_num=repo_info.commit, entry_point=entry_point, working_dir=working_dir, - diff=repo_info.diff, + diff=diff, + requirements={'pip': requirements} if requirements else None, ) messages = [] @@ -145,7 +351,7 @@ class ScriptInfo(object): return ScriptInfoResult(script=script_info, warning_messages=messages) @classmethod - def get(cls, filepath=sys.argv[0], check_uncommitted=False, log=None): + def get(cls, filepath=sys.argv[0], check_uncommitted=True, log=None): try: return cls._get_script_info( filepath=filepath, check_uncommitted=check_uncommitted, log=log