clearml/trains/backend_interface/task/repo/scriptinfo.py

import os
import sys
import collections
import logging
from tempfile import mkstemp
from threading import Thread, Event

import attr
from furl import furl
from pathlib2 import Path

from ....backend_api import Session
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult

_logger = get_logger("Repository Detection")


class ScriptInfoError(Exception):
    pass


class ScriptRequirements(object):
    def __init__(self, root_folder):
        self._root_folder = root_folder

    @staticmethod
    def get_installed_pkgs_detail(reqs):
        """
        HACK: bugfix of the original pigar get_installed_pkgs_detail.
        Get a mapping from import top-level name
        to installed package name with version.
        """
        mapping = dict()
        for path in sys.path:
            if os.path.isdir(path) and path.rstrip('/').endswith(
                    ('site-packages', 'dist-packages')):
                new_mapping = reqs._search_path(path)
                # BUGFIX:
                # Override with the previous mapping: just like Python resolves imports,
                # the first match on sys.path is the one used,
                # unlike the original implementation, where the last one is used.
                new_mapping.update(mapping)
                mapping = new_mapping

        return mapping
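
    # Illustrative sketch of the first-match semantics (paths/versions are
    # hypothetical): if two sys.path entries both provide the import name 'yaml',
    #   /venv/lib/site-packages       -> {'yaml': ('PyYAML', '5.1')}
    #   /usr/lib/python/dist-packages -> {'yaml': ('PyYAML', '3.13')}
    # the resulting mapping keeps the first sys.path match, ('PyYAML', '5.1'),
    # matching Python's own import resolution order.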

    def get_requirements(self):
        try:
            from pigar import reqs
            reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
            from pigar.__main__ import GenerateReqs
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
            try:
                # first try our version; if we fail, revert to the internal implementation
                installed_pkgs = self.get_installed_pkgs_detail(reqs)
            except Exception:
                installed_pkgs = reqs.get_installed_pkgs_detail()
            gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
                              ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
            reqs, try_imports, guess = gr.extract_reqs()
            return self.create_requirements_txt(reqs)
        except Exception:
            return ''
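
    # Minimal usage sketch (assumes pigar is installed; the folder below is hypothetical):
    #   script_requirements = ScriptRequirements('/path/to/project')
    #   requirements_txt = script_requirements.get_requirements()
    #   # -> pip-style text (see create_requirements_txt below), or '' on failure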

    @staticmethod
    def _patched_project_import_modules(project_path, ignores):
        """
        Copied from pigar reqs.project_import_modules;
        patched because os.getcwd() is incorrectly used instead of the project path.
        """
        from pigar.modules import ImportedModules
        from pigar.reqs import file_import_modules

        modules = ImportedModules()
        try_imports = set()
        local_mods = list()
        cur_dir = project_path  # os.getcwd()
        ignore_paths = collections.defaultdict(set)
        if not ignores:
            ignore_paths[project_path].add('.git')
        else:
            for path in ignores:
                parent_dir = os.path.dirname(path)
                ignore_paths[parent_dir].add(os.path.basename(path))

        for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
            if dirpath in ignore_paths:
                dirnames[:] = [d for d in dirnames
                               if d not in ignore_paths[dirpath]]
            py_files = list()
            for fn in files:
                # C extension.
                if fn.endswith('.so'):
                    local_mods.append(fn[:-3])
                # Normal Python file.
                if fn.endswith('.py'):
                    local_mods.append(fn[:-3])
                    py_files.append(fn)
            if '__init__.py' in files:
                local_mods.append(os.path.basename(dirpath))
            for file in py_files:
                fpath = os.path.join(dirpath, file)
                # report paths relative to the project root, not the absolute path
                fake_path = fpath.split(cur_dir)[1][1:]
                with open(fpath, 'rb') as f:
                    fmodules, try_ipts = file_import_modules(fake_path, f.read())
                    modules |= fmodules
                    try_imports |= try_ipts

        # hack: forcefully insert storage modules if we have them
        # noinspection PyBroadException
        try:
            import boto3
            modules.add('boto3', 'trains.storage', 0)
        except Exception:
            pass
        # noinspection PyBroadException
        try:
            from google.cloud import storage
            modules.add('google_cloud_storage', 'trains.storage', 0)
        except Exception:
            pass
        # noinspection PyBroadException
        try:
            from azure.storage.blob import ContentSettings
            modules.add('azure_storage_blob', 'trains.storage', 0)
        except Exception:
            pass

        return modules, try_imports, local_mods

    @staticmethod
    def create_requirements_txt(reqs):
        # write requirements.txt
        # python version header
        requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'

        # requirement summary
        requirements_txt += '\n'
        for k, v in reqs.sorted_items():
            # requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
            if k == '-e':
                requirements_txt += '{0} {1}\n'.format(k, v.version)
            elif v:
                requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version)
            else:
                requirements_txt += '{0}\n'.format(k)

        # requirements details (in comments)
        requirements_txt += '\n' + \
                            '# Detailed import analysis\n' \
                            '# **************************\n'
        for k, v in reqs.sorted_items():
            requirements_txt += '\n'
            if k == '-e':
                requirements_txt += '# IMPORT PACKAGE {0} {1}\n'.format(k, v.version)
            else:
                requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
            requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])

        return requirements_txt
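
    # Sample of the generated text (package names, versions and locations are
    # illustrative only):
    #   # Python 3.6.8 (default, ...)
    #
    #   numpy == 1.16.4
    #   requests == 2.22.0
    #
    #   # Detailed import analysis
    #   # **************************
    #
    #   # IMPORT PACKAGE numpy
    #   # train.py: 3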


class _JupyterObserver(object):
    _thread = None
    _exit_event = Event()
    _sync_event = Event()
    _sample_frequency = 30.
    _first_sample_frequency = 3.

    @classmethod
    def observer(cls, jupyter_notebook_filename):
        if cls._thread is not None:
            # order of signaling is important!
            cls._exit_event.set()
            cls._sync_event.set()
            cls._thread.join()
            cls._sync_event.clear()
            cls._exit_event.clear()

        cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
        cls._thread.daemon = True
        cls._thread.start()

    @classmethod
    def signal_sync(cls, *_):
        cls._sync_event.set()

    @classmethod
    def _daemon(cls, jupyter_notebook_filename):
        from trains import Task

        # load jupyter notebook package
        # noinspection PyBroadException
        try:
            from nbconvert.exporters.script import ScriptExporter
            _script_exporter = ScriptExporter()
        except Exception:
            return
        # load pigar
        # noinspection PyBroadException
        try:
            from pigar.reqs import get_installed_pkgs_detail, file_import_modules
            from pigar.modules import ReqsModules
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
        except Exception:
            file_import_modules = None
        # load IPython
        # noinspection PyBroadException
        try:
            from IPython import get_ipython
        except Exception:
            # should not happen
            get_ipython = None

        # setup local notebook files
        if jupyter_notebook_filename:
            notebook = Path(jupyter_notebook_filename)
            local_jupyter_filename = jupyter_notebook_filename
        else:
            notebook = None
            fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
            os.close(fd)

        last_update_ts = None
        counter = 0
        prev_script_hash = None

        # main observer loop
        while True:
            # wait for timeout or sync event
            cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
            # check if we need to exit
            if cls._exit_event.wait(timeout=0.):
                return
            cls._sync_event.clear()
            counter += 1
            # noinspection PyBroadException
            try:
                # if there is no task connected, do nothing
                task = Task.current_task()
                if not task:
                    continue
                # if we have a local file:
                if notebook:
                    if not notebook.exists():
                        continue
                    # check if the notebook changed
                    if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
                        continue
                    last_update_ts = notebook.stat().st_mtime
                else:
                    # serialize notebook to a temp file
                    # noinspection PyBroadException
                    try:
                        get_ipython().run_line_magic('notebook', local_jupyter_filename)
                    except Exception:
                        continue

                # get notebook python script
                script_code, resources = _script_exporter.from_filename(local_jupyter_filename)
                current_script_hash = hash(script_code)
                if prev_script_hash and prev_script_hash == current_script_hash:
                    continue

                requirements_txt = ''
                # parse jupyter python script and prepare pip requirements (pigar)
                # if the backend supports requirements
                if file_import_modules and Session.check_min_api_version('2.2'):
                    # BUGFIX: use the serialized filename; `notebook` is None when
                    # there is no local .ipynb file (e.g. google.colab)
                    fmodules, _ = file_import_modules(Path(local_jupyter_filename).parts[-1], script_code)
                    installed_pkgs = get_installed_pkgs_detail()
                    reqs = ReqsModules()
                    for name in fmodules:
                        if name in installed_pkgs:
                            pkg_name, version = installed_pkgs[name]
                            reqs.add(pkg_name, version, fmodules[name])
                    requirements_txt = ScriptRequirements.create_requirements_txt(reqs)

                # update script
                prev_script_hash = current_script_hash
                data_script = task.data.script
                data_script.diff = script_code
                data_script.requirements = {'pip': requirements_txt}
                task._update_script(script=data_script)
                # update requirements
                task._update_requirements(requirements=requirements_txt)
            except Exception:
                pass
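
    # Lifecycle note: ScriptInfo._jupyter_install_post_store_hook() (below) starts
    # this observer once per kernel and registers signal_sync() on IPython's
    # 'pre_run_cell' event, so every executed cell wakes the daemon early and the
    # stored notebook diff plus requirements track the live notebook state.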


class ScriptInfo(object):

    plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
    """ Script info detection plugins, in order of priority """

    @classmethod
    def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename):
        # noinspection PyBroadException
        try:
            if 'IPython' in sys.modules:
                from IPython import get_ipython
                if get_ipython():
                    _JupyterObserver.observer(jupyter_notebook_filename)
                    get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
        except Exception:
            pass

    @classmethod
    def _get_jupyter_notebook_filename(cls):
        if not (sys.argv[0].endswith(os.path.sep + 'ipykernel_launcher.py') or
                sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
                or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
            return None

        # we can safely assume that we can import the notebook package here
        # noinspection PyBroadException
        try:
            from notebook.notebookapp import list_running_servers
            import requests

            current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
            server_info = next(list_running_servers())
            r = requests.get(
                url=server_info['url'] + 'api/sessions',
                headers={'Authorization': 'token {}'.format(server_info.get('token', '')), })
            r.raise_for_status()
            notebooks = r.json()

            cur_notebook = None
            for n in notebooks:
                if n['kernel']['id'] == current_kernel:
                    cur_notebook = n
                    break

            notebook_path = cur_notebook['notebook'].get('path', '')
            notebook_name = cur_notebook['notebook'].get('name', '')

            is_google_colab = False
            # check if this is google.colab; if so, there is no local file
            # noinspection PyBroadException
            try:
                from IPython import get_ipython
                if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
                    is_google_colab = True
            except Exception:
                pass

            if is_google_colab:
                script_entry_point = notebook_name
                local_ipynb_file = None
            else:
                # always forward slash, because this comes from a URI (so never a backslash, not even on Windows)
                entry_point_filename = notebook_path.split('/')[-1]

                # now we should try to find the actual file
                entry_point = (Path.cwd() / entry_point_filename).absolute()
                if not entry_point.is_file():
                    entry_point = (Path.cwd() / notebook_path).absolute()

                # get the local ipynb for the observer
                local_ipynb_file = entry_point.as_posix()

                # now replace the .ipynb with .py
                # we assume we will have that file available with the Jupyter notebook plugin
                entry_point = entry_point.with_suffix('.py')
                script_entry_point = entry_point.as_posix()

            # install the post store hook;
            # notice that if we do not have a local file we serialize/write the entire notebook every time
            cls._jupyter_install_post_store_hook(local_ipynb_file)

            return script_entry_point
        except Exception:
            return None

    @classmethod
    def _get_entry_point(cls, repo_root, script_path):
        repo_root = Path(repo_root).absolute()

        try:
            # Use os.path.relpath as it calculates up dir movements (../)
            entry_point = os.path.relpath(str(script_path), str(Path.cwd()))
        except ValueError:
            # Working directory not under repository root
            entry_point = script_path.relative_to(repo_root)

        return Path(entry_point).as_posix()
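
    # Worked example (hypothetical paths): with cwd=/repo/subdir and
    # script_path=/repo/train.py, os.path.relpath() yields '../train.py';
    # on the ValueError fallback (e.g. a different drive on Windows) the
    # path is taken relative to the repository root instead: 'train.py'.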

    @classmethod
    def _get_working_dir(cls, repo_root):
        repo_root = Path(repo_root).absolute()

        try:
            return Path.cwd().relative_to(repo_root).as_posix()
        except ValueError:
            # Working directory not under repository root
            return os.path.curdir
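
    # Worked example (hypothetical paths): with repo_root=/repo and
    # cwd=/repo/experiments this returns 'experiments'; with cwd outside the
    # repository it falls back to '.' (os.path.curdir).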

    @classmethod
    def _get_script_code(cls, script_path):
        # noinspection PyBroadException
        try:
            with open(script_path, 'r') as f:
                script_code = f.read()
            return script_code
        except Exception:
            pass
        return ''

    @classmethod
    def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
        jupyter_filepath = cls._get_jupyter_notebook_filename()
        if jupyter_filepath:
            script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
        else:
            script_path = Path(os.path.normpath(filepath)).absolute()
            if not script_path.is_file():
                raise ScriptInfoError(
                    "Script file [{}] could not be found".format(filepath)
                )

        script_dir = script_path.parent

        def _log(msg, *args, **kwargs):
            if not log:
                return
            log.warning(
                "Failed auto-detecting task repository: {}".format(
                    msg.format(*args, **kwargs)
                )
            )

        plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
        repo_info = DetectionResult()
        if not plugin:
            # BUGFIX: guard the call, `log` may be None (the default)
            if log:
                log.info("No repository found, storing script code instead")
        else:
            try:
                repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
            except Exception as ex:
                _log("no info for {} ({})", script_dir, ex)
            else:
                if repo_info.is_empty():
                    _log("no info for {}", script_dir)

        repo_root = repo_info.root or script_dir
        working_dir = cls._get_working_dir(repo_root)
        entry_point = cls._get_entry_point(repo_root, script_path)
        if check_uncommitted:
            diff = cls._get_script_code(script_path.as_posix()) \
                if not plugin or not repo_info.commit else repo_info.diff
        else:
            diff = ''

        # if this is not jupyter, get the requirements.txt
        requirements = ''
        # create requirements if the backend supports requirements
        if create_requirements and not jupyter_filepath and Session.check_min_api_version('2.2'):
            script_requirements = ScriptRequirements(Path(repo_root).as_posix())
            requirements = script_requirements.get_requirements()

        script_info = dict(
            repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
            branch=repo_info.branch,
            version_num=repo_info.commit,
            entry_point=entry_point,
            working_dir=working_dir,
            diff=diff,
            requirements={'pip': requirements} if requirements else None,
        )

        messages = []
        if repo_info.modified:
            messages.append(
                "======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY {} <======".format(
                    script_info.get("repository", "")
                )
            )

        if not any(script_info.values()):
            script_info = None

        return ScriptInfoResult(script=script_info, warning_messages=messages)

    @classmethod
    def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
        try:
            return cls._get_script_info(
                filepath=filepath, check_uncommitted=check_uncommitted,
                create_requirements=create_requirements, log=log)
        except Exception as ex:
            if log:
                log.warning("Failed auto-detecting task repository: {}".format(ex))
        return ScriptInfoResult()
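
    # Minimal usage sketch (field values depend on the local checkout):
    #   result = ScriptInfo.get(filepath='train.py')
    #   if result.script:
    #       print(result.script['repository'], result.script['entry_point'])
    #   for warning in result.warning_messages:
    #       print(warning)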


@attr.s
class ScriptInfoResult(object):
    script = attr.ib(default=None)
    warning_messages = attr.ib(factory=list)