Add automatic requirements.txt generation

This commit is contained in:
allegroai 2019-07-06 23:01:15 +03:00
parent 3050bf1476
commit 577010c421
2 changed files with 211 additions and 4 deletions

View File

@ -17,6 +17,7 @@ jsonschema>=2.6.0
numpy>=1.10 numpy>=1.10
opencv-python>=3.2.0.8 opencv-python>=3.2.0.8
pathlib2>=2.3.0 pathlib2>=2.3.0
pigar>=0.9.2
plotly>=3.9.0 plotly>=3.9.0
psutil>=3.4.2 psutil>=3.4.2
pyhocon>=0.3.38 pyhocon>=0.3.38
@ -24,7 +25,7 @@ python-dateutil>=2.6.1
pyjwt>=1.6.4 pyjwt>=1.6.4
PyYAML>=3.12 PyYAML>=3.12
requests-file>=1.4.2 requests-file>=1.4.2
requests>=2.18.4 requests>=2.20.0
six>=1.11.0 six>=1.11.0
tqdm>=4.19.5 tqdm>=4.19.5
typing>=3.6.4 typing>=3.6.4

View File

@ -2,9 +2,13 @@ import os
import sys import sys
import attr import attr
import collections
import logging
from furl import furl from furl import furl
from pathlib2 import Path from pathlib2 import Path
from threading import Thread, Event
from ....backend_api import Session
from ....debugging import get_logger from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
@ -15,17 +19,197 @@ class ScriptInfoError(Exception):
pass pass
class ScriptRequirements(object):
    """
    Build a pip ``requirements.txt`` style listing for a project folder using pigar.

    pigar is imported lazily inside the methods, so defining this class does
    not require pigar to be installed.
    """

    def __init__(self, root_folder):
        """
        :param root_folder: project root folder that is scanned for imports.
        """
        self._root_folder = root_folder

    def get_requirements(self):
        """
        Scan the project folder and return the detected requirements as text.

        This is a best-effort call: any failure (pigar missing, unreadable
        files, parsing errors) results in an empty string, never an exception.

        :return: requirements text, or '' if it could not be generated.
        """
        # noinspection PyBroadException
        try:
            from pigar import reqs as pigar_reqs
            # NOTE: this replaces pigar's module level function for the whole process.
            # It is done before importing pigar.__main__ - presumably so that module
            # binds the patched function when it is first imported (confirm against pigar).
            pigar_reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
            from pigar.__main__ import GenerateReqs
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
            installed_pkgs = pigar_reqs.get_installed_pkgs_detail()
            gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
                              ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
            detected_reqs, _try_imports, _guess = gr.extract_reqs()
            return self.create_requirements_txt(detected_reqs)
        except Exception:
            # keep the best-effort contract, but leave a trace for debugging
            logging.getLogger(__name__).debug('Failed generating requirements', exc_info=True)
            return ''

    @staticmethod
    def _ignore_rules(project_path, ignores):
        """
        Translate a list of ignore entries into directory pruning rules.

        * An entry without a directory part (e.g. '.git') is ignored at any depth.
        * An entry with a directory part is ignored only under that parent folder;
          relative parents are resolved against ``project_path``.
        * With no entries at all, only '.git' directly under ``project_path`` is ignored.

        :return: tuple (names ignored everywhere, {normalized parent dir: set of names})
        """
        ignore_names = set()
        ignore_paths = collections.defaultdict(set)
        if not ignores:
            ignore_paths[os.path.normpath(project_path)].add('.git')
            return ignore_names, ignore_paths

        for path in ignores:
            if not path:
                continue
            parent_dir, name = os.path.split(os.path.normpath(path))
            if not parent_dir:
                ignore_names.add(name)
                continue
            if not os.path.isabs(parent_dir):
                parent_dir = os.path.join(project_path, parent_dir)
            ignore_paths[os.path.normpath(parent_dir)].add(name)
        return ignore_names, ignore_paths

    @staticmethod
    def _patched_project_import_modules(project_path, ignores):
        """
        Collect the modules imported by every python file under ``project_path``.

        Copied from pigar ``reqs.project_import_modules`` and patched to work
        relative to ``project_path`` instead of ``os.getcwd()``.

        :param project_path: root folder to walk (symbolic links are followed).
        :param ignores: directory names / paths to skip, see ``_ignore_rules``.
        :return: tuple (imported modules, try-imports set, list of local module names)
        """
        from pigar.modules import ImportedModules
        from pigar.reqs import file_import_modules

        modules = ImportedModules()
        try_imports = set()
        local_mods = list()
        ignore_names, ignore_paths = ScriptRequirements._ignore_rules(project_path, ignores)

        for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
            # prune in place so os.walk does not descend into ignored folders
            skipped = ignore_names | ignore_paths.get(os.path.normpath(dirpath), set())
            if skipped:
                dirnames[:] = [d for d in dirnames if d not in skipped]

            py_files = list()
            for fn in files:
                # C extension.
                if fn.endswith('.so'):
                    local_mods.append(fn[:-3])
                # Normal Python file.
                if fn.endswith('.py'):
                    local_mods.append(fn[:-3])
                    py_files.append(fn)
            if '__init__.py' in files:
                local_mods.append(os.path.basename(dirpath))

            for file_name in py_files:
                fpath = os.path.join(dirpath, file_name)
                # path reported to pigar is relative to the project root
                fake_path = os.path.relpath(fpath, project_path)
                with open(fpath, 'rb') as f:
                    fmodules, try_ipts = file_import_modules(fake_path, f.read())
                    modules |= fmodules
                    try_imports |= try_ipts

        return modules, try_imports, local_mods

    @staticmethod
    def create_requirements_txt(reqs):
        """
        Render detected requirements as ``requirements.txt`` text.

        :param reqs: object exposing ``sorted_items()`` yielding (name, detail) pairs,
            where ``detail.version`` is the version string and
            ``detail.comments.sorted_items()`` yields comment lines.
        :return: the requirements text, starting with a '# Python <version>' header.
        """
        lines = ['# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n']
        for k, v in reqs.sorted_items():
            lines.append('\n')
            lines.extend('# {0}\n'.format(c) for c in v.comments.sorted_items())
            if k == '-e':
                lines.append('{0} {1}\n'.format(k, v.version))
            elif v.version:
                lines.append('{0} {1} {2}\n'.format(k, '==', v.version))
            else:
                # no known version, pin nothing rather than emit a dangling '=='
                lines.append('{0}\n'.format(k))
        return ''.join(lines)
class _JupyterObserver(object):
    """
    Background watcher that keeps the current task's stored script in sync
    with a Jupyter notebook file.

    A single daemon thread periodically converts the notebook to a python
    script and, when the script changed, updates the task's script diff and
    (if pigar is available and the backend supports it) its pip requirements.
    """
    _thread = None
    _exit_event = Event()
    # seconds between samples, and the shorter delay before the very first sample
    _sample_frequency = 60.
    _first_sample_frequency = 3.

    @classmethod
    def observer(cls, jupyter_notebook_filename):
        """
        (Re)start the watcher thread for the given notebook file.

        If a watcher thread was already started, it is signaled to exit and
        joined before the new one is created.

        :param jupyter_notebook_filename: path of the notebook (.ipynb) to watch.
        """
        if cls._thread is not None:
            cls._exit_event.set()
            cls._thread.join()

        cls._exit_event.clear()
        cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
        cls._thread.daemon = True
        cls._thread.start()

    @classmethod
    def _daemon(cls, jupyter_notebook_filename):
        """
        Thread body: poll the notebook file until the exit event is set.

        Returns immediately if nbconvert's ScriptExporter cannot be loaded.
        Errors inside a single polling iteration are swallowed so the
        watcher keeps running.
        """
        from trains import Task

        # load jupyter notebook package
        # noinspection PyBroadException
        try:
            from nbconvert.exporters.script import ScriptExporter
            _script_exporter = ScriptExporter()
        except Exception:
            return

        # load pigar
        # noinspection PyBroadException
        try:
            from pigar.reqs import get_installed_pkgs_detail, file_import_modules
            from pigar.modules import ReqsModules
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
        except Exception:
            # requirements detection is optional, None disables it below
            file_import_modules = None

        # main observer loop
        notebook = Path(jupyter_notebook_filename)
        last_update_ts = None
        counter = 0
        prev_script_hash = None
        while True:
            # wait() returns True only when the exit event was set
            if cls._exit_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency):
                return
            counter += 1
            # noinspection PyBroadException
            try:
                if not notebook.exists():
                    continue
                # check if notebook changed (single stat call, so check and record agree)
                mtime = notebook.stat().st_mtime
                if last_update_ts is not None and mtime - last_update_ts <= 0:
                    continue
                task = Task.current_task()
                if not task:
                    # no task yet: do not record the timestamp, so this version
                    # of the notebook is retried on the next sample
                    continue
                last_update_ts = mtime

                # get notebook python script
                script_code, resources = _script_exporter.from_filename(jupyter_notebook_filename)
                current_script_hash = hash(script_code)
                if prev_script_hash is not None and prev_script_hash == current_script_hash:
                    continue

                requirements_txt = ''
                # parse jupyter python script and prepare pip requirements (pigar)
                # if backend supports requirements
                # NOTE(review): if api_version is a str this is a lexicographic comparison - confirm
                if file_import_modules and Session.api_version > '2.1':
                    fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
                    installed_pkgs = get_installed_pkgs_detail()
                    reqs = ReqsModules()
                    for name in fmodules:
                        if name in installed_pkgs:
                            pkg_name, version = installed_pkgs[name]
                            reqs.add(pkg_name, version, fmodules[name])
                    requirements_txt = ScriptRequirements.create_requirements_txt(reqs)

                # update script
                prev_script_hash = current_script_hash
                data_script = task.data.script
                data_script.diff = script_code
                data_script.requirements = {'pip': requirements_txt}
                task._update_script(script=data_script)
                # update requirements
                if requirements_txt:
                    task._update_requirements(requirements=requirements_txt)
            except Exception:
                # best effort, a failed sample must not stop the watcher
                pass
class ScriptInfo(object): class ScriptInfo(object):
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()] plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
""" Script info detection plugins, in order of priority """ """ Script info detection plugins, in order of priority """
@classmethod
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename):
    """
    Start the notebook observer when running inside an IPython session.

    Does nothing if IPython was not already imported or no IPython shell
    is active; any error is silently ignored.

    :param jupyter_notebook_filename: path of the notebook file to observe.
    """
    # noinspection PyBroadException
    try:
        if 'IPython' not in sys.modules:
            return
        from IPython import get_ipython
        if not get_ipython():
            return
        _JupyterObserver.observer(jupyter_notebook_filename)
    except Exception:
        pass
@classmethod @classmethod
def _get_jupyter_notebook_filename(cls): def _get_jupyter_notebook_filename(cls):
if not sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'): if not sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
return None return None
# we can safely assume that we can import the notebook package here # we can safely assume that we can import the notebook package here
# noinspection PyBroadException
try: try:
from notebook.notebookapp import list_running_servers from notebook.notebookapp import list_running_servers
import requests import requests
@ -51,6 +235,9 @@ class ScriptInfo(object):
if not entry_point.is_file(): if not entry_point.is_file():
entry_point = (Path.cwd() / notebook_path).absolute() entry_point = (Path.cwd() / notebook_path).absolute()
# install the post store hook, so always have a synced file in the system
cls._jupyter_install_post_store_hook(entry_point.as_posix())
# now replace the .ipynb with .py # now replace the .ipynb with .py
# we assume we will have that file available with the Jupyter notebook plugin # we assume we will have that file available with the Jupyter notebook plugin
entry_point = entry_point.with_suffix('.py') entry_point = entry_point.with_suffix('.py')
@ -83,7 +270,18 @@ class ScriptInfo(object):
return os.path.curdir return os.path.curdir
@classmethod @classmethod
def _get_script_info(cls, filepath, check_uncommitted=False, log=None): def _get_script_code(cls, script_path):
# noinspection PyBroadException
try:
with open(script_path, 'r') as f:
script_code = f.read()
return script_code
except Exception:
pass
return ''
@classmethod
def _get_script_info(cls, filepath, check_uncommitted=True, log=None):
jupyter_filepath = cls._get_jupyter_notebook_filename() jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath: if jupyter_filepath:
script_path = Path(os.path.normpath(jupyter_filepath)).absolute() script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
@ -121,6 +319,13 @@ class ScriptInfo(object):
repo_root = repo_info.root or script_dir repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root) working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path) entry_point = cls._get_entry_point(repo_root, script_path)
diff = cls._get_script_code(script_path.as_posix()) if not plugin or not repo_info.commit else repo_info.diff
# if this is not jupyter, get the requirements.txt
requirements = ''
# create requirements if backend supports requirements
if not jupyter_filepath and Session.api_version > '2.1':
script_requirements = ScriptRequirements(Path(repo_root).as_posix())
requirements = script_requirements.get_requirements()
script_info = dict( script_info = dict(
repository=furl(repo_info.url).remove(username=True, password=True).tostr(), repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
@ -128,7 +333,8 @@ class ScriptInfo(object):
version_num=repo_info.commit, version_num=repo_info.commit,
entry_point=entry_point, entry_point=entry_point,
working_dir=working_dir, working_dir=working_dir,
diff=repo_info.diff, diff=diff,
requirements={'pip': requirements} if requirements else None,
) )
messages = [] messages = []
@ -145,7 +351,7 @@ class ScriptInfo(object):
return ScriptInfoResult(script=script_info, warning_messages=messages) return ScriptInfoResult(script=script_info, warning_messages=messages)
@classmethod @classmethod
def get(cls, filepath=sys.argv[0], check_uncommitted=False, log=None): def get(cls, filepath=sys.argv[0], check_uncommitted=True, log=None):
try: try:
return cls._get_script_info( return cls._get_script_info(
filepath=filepath, check_uncommitted=check_uncommitted, log=log filepath=filepath, check_uncommitted=check_uncommitted, log=log