import os
import sys
from copy import copy
from functools import partial
from tempfile import mkstemp
import attr
import logging
import json
from pathlib2 import Path
from threading import Thread, Event
from .util import get_command_output, remove_user_pass_from_url
from ....backend_api import Session
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
_logger = get_logger("Repository Detection")
class ScriptInfoError(Exception):
pass
class ScriptRequirements(object):
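    """Extract the pip / conda package requirements of a script folder, using the vendored pigar package."""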
_max_requirements_size = 512 * 1024
def __init__(self, root_folder):
self._root_folder = root_folder
def get_requirements(self, entry_point_filename=None):
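        """Analyze the imports under the root folder and return a pair of strings
        (pip requirements, conda requirements); returns ('', '') on any failure."""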
# noinspection PyBroadException
try:
from ....utilities.pigar.reqs import get_installed_pkgs_detail
from ....utilities.pigar.__main__ import GenerateReqs
installed_pkgs = get_installed_pkgs_detail()
gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints',
'site-packages', 'dist-packages'])
reqs, try_imports, guess, local_pks = gr.extract_reqs(
module_callback=ScriptRequirements.add_trains_used_packages, entry_point_filename=entry_point_filename)
return self.create_requirements_txt(reqs, local_pks)
except Exception:
return '', ''
@staticmethod
def add_trains_used_packages(modules):
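        """Post-process the detected modules: force-add the cloud storage SDKs
        (boto3, google-cloud-storage, azure-storage-blob) when importable, map
        sklearn to its pypi name scikit-learn, add tensorboard when this torch
        build bundles it, drop setuptools, and merge Task._force_requirements."""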
# hack: forcefully insert storage modules if we have them
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
import boto3 # noqa: F401
modules.add('boto3', 'trains.storage', 0)
except Exception:
pass
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
from google.cloud import storage # noqa: F401
modules.add('google_cloud_storage', 'trains.storage', 0)
except Exception:
pass
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
from azure.storage.blob import ContentSettings # noqa: F401
modules.add('azure_storage_blob', 'trains.storage', 0)
except Exception:
pass
# bugfix, replace sklearn with scikit-learn name
if 'sklearn' in modules:
sklearn = modules.pop('sklearn', {})
for fname, lines in sklearn.items():
modules.add('scikit_learn', fname, lines)
# if we have torch and it supports tensorboard, we should add that as well
# (because it will not be detected automatically)
if 'torch' in modules and 'tensorboard' not in modules:
# noinspection PyBroadException
try:
            # see if this version of torch supports tensorboard
# noinspection PyPackageRequirements,PyUnresolvedReferences
import torch.utils.tensorboard # noqa: F401
# noinspection PyPackageRequirements,PyUnresolvedReferences
import tensorboard # noqa: F401
modules.add('tensorboard', 'torch', 0)
except Exception:
pass
# remove setuptools, we should not specify this module version. It is installed by default
if 'setuptools' in modules:
modules.pop('setuptools', {})
# add forced requirements:
# noinspection PyBroadException
try:
from ..task import Task
# noinspection PyProtectedMember
for package, version in Task._force_requirements.items():
modules.add(package, 'trains', 0)
except Exception:
pass
return modules
@staticmethod
def create_requirements_txt(reqs, local_pks=None):
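        """Render the detected packages as a pip requirements.txt-style string plus a
        matching conda requirements string (based on `conda list --json` when running
        inside a conda environment)."""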
# write requirements.txt
# noinspection PyBroadException
try:
conda_requirements = ''
conda_prefix = os.environ.get('CONDA_PREFIX')
if conda_prefix and not conda_prefix.endswith(os.path.sep):
conda_prefix += os.path.sep
if conda_prefix and sys.executable.startswith(conda_prefix):
conda_packages_json = get_command_output(['conda', 'list', '--json'])
conda_packages_json = json.loads(conda_packages_json)
reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
for r in conda_packages_json:
                # if this is a pypi package (or has no conda channel), leave it out of the conda requirements
if not r.get('channel') or r.get('channel') == 'pypi':
continue
# check if we have it in our required packages
name = r['name'].lower().replace('-', '_')
                # hack: support the pytorch/torch naming-convention difference
if name == 'pytorch':
name = 'torch'
k, v = reqs_lower.get(name, (None, None))
if k and v is not None:
if v.version:
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
else:
conda_requirements += '{0}\n'.format(k)
except Exception:
conda_requirements = ''
# add forced requirements:
# noinspection PyBroadException
try:
from ..task import Task
# noinspection PyProtectedMember
forced_packages = copy(Task._force_requirements)
except Exception:
forced_packages = {}
# python version header
requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
if local_pks:
requirements_txt += '\n# Local modules found - skipping:\n'
for k, v in local_pks.sorted_items():
if v.version:
requirements_txt += '# {0} == {1}\n'.format(k, v.version)
else:
requirements_txt += '# {0}\n'.format(k)
# requirement summary
requirements_txt += '\n'
for k, v in reqs.sorted_items():
version = v.version
if k in forced_packages:
forced_version = forced_packages.pop(k, None)
if forced_version:
version = forced_version
# requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
if k == '-e' and version:
requirements_txt += '{0}\n'.format(version)
elif k.startswith('-e '):
requirements_txt += '{0} {1}\n'.format(k.replace('-e ', '', 1), version or '')
elif version:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', version)
else:
requirements_txt += '{0}\n'.format(k)
# add forced requirements that we could not find installed on the system
for k in sorted(forced_packages.keys()):
if forced_packages[k]:
requirements_txt += '{0} {1} {2}\n'.format(k, '==', forced_packages[k])
else:
requirements_txt += '{0}\n'.format(k)
requirements_txt_packages_only = \
requirements_txt + '\n# Skipping detailed import analysis, it is too large\n'
# requirements details (in comments)
requirements_txt += '\n' + \
'# Detailed import analysis\n' \
'# **************************\n'
if local_pks:
for k, v in local_pks.sorted_items():
requirements_txt += '\n'
requirements_txt += '# IMPORT LOCAL PACKAGE {0}\n'.format(k)
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
for k, v in reqs.sorted_items():
requirements_txt += '\n'
if k == '-e':
requirements_txt += '# IMPORT PACKAGE {0} {1}\n'.format(k, v.version)
else:
requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
        # make sure we do not exceed the size limit
return (requirements_txt if len(requirements_txt) < ScriptRequirements._max_requirements_size
else requirements_txt_packages_only,
conda_requirements)
class _JupyterObserver(object):
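    """Background daemon that periodically serializes the running Jupyter notebook
    to python code and pushes the updated diff and requirements to the current Task."""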
_thread = None
_exit_event = Event()
_sync_event = Event()
_sample_frequency = 30.
_first_sample_frequency = 3.
_jupyter_history_logger = None
@classmethod
def observer(cls, jupyter_notebook_filename, log_history):
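        """(Re)start the observer daemon thread for the given notebook; a previously
        running thread is signaled to exit and joined first."""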
if cls._thread is not None:
# order of signaling is important!
cls._exit_event.set()
cls._sync_event.set()
cls._thread.join()
if log_history and cls._jupyter_history_logger is None:
cls._jupyter_history_logger = _JupyterHistoryLogger()
cls._jupyter_history_logger.hook()
cls._sync_event.clear()
cls._exit_event.clear()
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
cls._thread.daemon = True
cls._thread.start()
@classmethod
def signal_sync(cls, *_, **__):
cls._sync_event.set()
@classmethod
def close(cls):
if not cls._thread:
return
cls._exit_event.set()
cls._sync_event.set()
cls._thread.join()
cls._thread = None
@classmethod
def _daemon(cls, jupyter_notebook_filename):
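        """Observer main loop: wake up on a sync event or sample timeout, export the
        notebook (or the IPython history) to python code, and when the script hash
        changes update the Task's script diff and requirements."""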
from trains import Task
# load jupyter notebook package
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter
_script_exporter = ScriptExporter()
except Exception:
return
# load pigar
# noinspection PyBroadException
try:
from ....utilities.pigar.reqs import get_installed_pkgs_detail, file_import_modules
from ....utilities.pigar.modules import ReqsModules
from ....utilities.pigar.log import logger
logger.setLevel(logging.WARNING)
except Exception:
file_import_modules = None
# load IPython
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from IPython import get_ipython
except Exception:
# should not happen
get_ipython = None
# setup local notebook files
if jupyter_notebook_filename:
notebook = Path(jupyter_notebook_filename)
local_jupyter_filename = jupyter_notebook_filename
else:
notebook = None
fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
os.close(fd)
last_update_ts = None
counter = 0
prev_script_hash = None
# noinspection PyBroadException
try:
from ....version import __version__
our_module = cls.__module__.split('.')[0], __version__
except Exception:
our_module = None
# noinspection PyBroadException
try:
import re
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)')
except Exception:
replace_ipython_pattern = None
# main observer loop, check if we need to exit
while not cls._exit_event.wait(timeout=0.):
# wait for timeout or sync event
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
cls._sync_event.clear()
counter += 1
# noinspection PyBroadException
try:
# if there is no task connected, do nothing
task = Task.current_task()
if not task:
continue
script_code = None
fmodules = None
current_cell = None
# if we have a local file:
if notebook:
if not notebook.exists():
continue
# check if notebook changed
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
continue
last_update_ts = notebook.stat().st_mtime
else:
# serialize notebook to a temp file
if cls._jupyter_history_logger:
script_code, current_cell = cls._jupyter_history_logger.history_to_str()
else:
# noinspection PyBroadException
try:
# noinspection PyBroadException
try:
os.unlink(local_jupyter_filename)
except Exception:
pass
get_ipython().run_line_magic('history', '-t -f {}'.format(local_jupyter_filename))
with open(local_jupyter_filename, 'r') as f:
script_code = f.read()
# load the modules
from ....utilities.pigar.modules import ImportedModules
fmodules = ImportedModules()
for nm in set([str(m).split('.')[0] for m in sys.modules]):
fmodules.add(nm, 'notebook', 0)
except Exception:
continue
# get notebook python script
if script_code is None:
script_code, _ = _script_exporter.from_filename(local_jupyter_filename)
current_script_hash = hash(script_code + (current_cell or ''))
if prev_script_hash and prev_script_hash == current_script_hash:
continue
                # comment out direct ipython access in the script code,
                # since we would not be able to run those calls anyhow
if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
requirements_txt = ''
conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar)
# if backend supports requirements
if file_import_modules and Session.check_min_api_version('2.2'):
if fmodules is None:
fmodules, _ = file_import_modules(
notebook.parts[-1] if notebook else 'notebook', script_code)
if current_cell:
cell_fmodules, _ = file_import_modules(
notebook.parts[-1] if notebook else 'notebook', current_cell)
# noinspection PyBroadException
try:
fmodules |= cell_fmodules
except Exception:
pass
# add current cell to the script
if current_cell:
script_code += '\n' + current_cell
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
# noinspection PyUnboundLocalVariable
installed_pkgs = get_installed_pkgs_detail()
                # make sure our own package appears in the installed packages list
if our_module and (our_module[0] not in installed_pkgs):
installed_pkgs[our_module[0]] = our_module
# noinspection PyUnboundLocalVariable
reqs = ReqsModules()
for name in fmodules:
if name in installed_pkgs:
pkg_name, version = installed_pkgs[name]
reqs.add(pkg_name, version, fmodules[name])
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
# update script
prev_script_hash = current_script_hash
data_script = task.data.script
data_script.diff = script_code
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
# noinspection PyProtectedMember
task._update_script(script=data_script)
# update requirements
# noinspection PyProtectedMember
task._update_requirements(requirements=requirements_txt)
except Exception:
pass
class ScriptInfo(object):
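    """Detect the execution context of the current script: repository, branch, commit,
    uncommitted diff, entry point, working directory and package requirements."""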
max_diff_size_bytes = 500000
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
""" Script info detection plugins, in order of priority """
@classmethod
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False):
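        """Start the notebook observer and register its sync callback on the IPython
        pre (and, when logging history, post) run-cell events."""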
# noinspection PyBroadException
try:
if 'IPython' in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython
if get_ipython():
_JupyterObserver.observer(jupyter_notebook_filename, log_history)
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
if log_history:
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
except Exception:
pass
@classmethod
def _get_jupyter_notebook_filename(cls):
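        """Return the .py entry point matching the running Jupyter notebook, or None when
        not running inside a Jupyter kernel. The kernel id is parsed from sys.argv and
        matched against the notebook server's /api/sessions listing."""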
if not (sys.argv[0].endswith(os.path.sep + 'ipykernel_launcher.py') or
sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
return None
# we can safely assume that we can import the notebook package here
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from notebook.notebookapp import list_running_servers
import requests
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
# noinspection PyBroadException
try:
server_info = next(list_running_servers())
except Exception:
# on some jupyter notebook versions this function can crash on parsing the json file,
# we will parse it manually here
# noinspection PyPackageRequirements
import ipykernel
from glob import glob
import json
                server_info = None
                for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')):
# noinspection PyBroadException
try:
with open(f, 'r') as json_data:
server_info = json.load(json_data)
except Exception:
server_info = None
if server_info:
break
try:
r = requests.get(
url=server_info['url'] + 'api/sessions',
headers={'Authorization': 'token {}'.format(server_info.get('token', '')), })
except requests.exceptions.SSLError:
# disable SSL check warning
from urllib3.exceptions import InsecureRequestWarning
# noinspection PyUnresolvedReferences
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
# fire request
r = requests.get(
url=server_info['url'] + 'api/sessions',
headers={'Authorization': 'token {}'.format(server_info.get('token', '')), }, verify=False)
# enable SSL check warning
import warnings
warnings.simplefilter('default', InsecureRequestWarning)
r.raise_for_status()
notebooks = r.json()
cur_notebook = None
for n in notebooks:
if n['kernel']['id'] == current_kernel:
cur_notebook = n
break
notebook_path = cur_notebook['notebook'].get('path', '')
notebook_name = cur_notebook['notebook'].get('name', '')
is_google_colab = False
# check if this is google.colab, then there is no local file
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from IPython import get_ipython
if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
is_google_colab = True
except Exception:
pass
if is_google_colab:
script_entry_point = str(notebook_name or 'notebook').replace(
'>', '_').replace('<', '_').replace('.ipynb', '.py')
if not script_entry_point.lower().endswith('.py'):
script_entry_point += '.py'
local_ipynb_file = None
else:
                # always a forward slash, because this came from a uri (so never a backslash, not even on windows)
entry_point_filename = notebook_path.split('/')[-1]
# now we should try to find the actual file
entry_point = (Path.cwd() / entry_point_filename).absolute()
if not entry_point.is_file():
entry_point = (Path.cwd() / notebook_path).absolute()
# get local ipynb for observer
local_ipynb_file = entry_point.as_posix()
# now replace the .ipynb with .py
# we assume we will have that file available with the Jupyter notebook plugin
entry_point = entry_point.with_suffix('.py')
script_entry_point = entry_point.as_posix()
            # install the post-store hook;
            # note that if we do not have a local file, we serialize the entire notebook every time
cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab)
return script_entry_point
except Exception:
return None
@classmethod
def _get_entry_point(cls, repo_root, script_path):
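        """Return the script path relative to the working directory (or to the repository
        root when the working directory is not under the repository), in posix form."""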
repo_root = Path(repo_root).absolute()
try:
# Use os.path.relpath as it calculates up dir movements (../)
entry_point = os.path.relpath(str(script_path), str(Path.cwd()))
except ValueError:
# Working directory not under repository root
entry_point = script_path.relative_to(repo_root)
return Path(entry_point).as_posix()
@classmethod
def _get_working_dir(cls, repo_root):
repo_root = Path(repo_root).absolute()
try:
return Path.cwd().relative_to(repo_root).as_posix()
except ValueError:
# Working directory not under repository root
return os.path.curdir
@classmethod
def _get_script_code(cls, script_path):
# noinspection PyBroadException
try:
with open(script_path, 'r') as f:
script_code = f.read()
return script_code
except Exception:
pass
return ''
@classmethod
def _get_script_info(cls, filepaths, check_uncommitted=True, create_requirements=True, log=None):
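        """Resolve the script (or notebook) path, detect its repository with the first
        matching plugin, collect the uncommitted diff and requirements, and return a
        (ScriptInfoResult, ScriptRequirements) pair."""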
jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath:
scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
else:
scripts_path = [Path(os.path.normpath(f)).absolute() for f in filepaths if f]
if all(not f.is_file() for f in scripts_path):
raise ScriptInfoError(
"Script file {} could not be found".format(scripts_path)
)
scripts_dir = [f.parent for f in scripts_path]
def _log(msg, *args, **kwargs):
if not log:
return
log.warning(
"Failed auto-detecting task repository: {}".format(
msg.format(*args, **kwargs)
)
)
plugin = next((p for p in cls.plugins if any(p.exists(d) for d in scripts_dir)), None)
repo_info = DetectionResult()
script_dir = scripts_dir[0]
script_path = scripts_path[0]
messages = []
auxiliary_git_diff = None
2019-06-10 17:00:28 +00:00
if not plugin:
log.info("No repository found, storing script code instead")
2019-06-10 17:00:28 +00:00
else:
try:
for i, d in enumerate(scripts_dir):
repo_info = plugin.get_info(str(d), include_diff=check_uncommitted)
if not repo_info.is_empty():
script_dir = d
script_path = scripts_path[i]
break
except Exception as ex:
_log("no info for {} ({})", scripts_dir, ex)
else:
if repo_info.is_empty():
_log("no info for {}", scripts_dir)
repo_root = repo_info.root or script_dir
if not plugin:
working_dir = '.'
entry_point = str(script_path.name)
else:
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
if check_uncommitted:
diff = cls._get_script_code(script_path.as_posix()) \
if not plugin or not repo_info.commit else repo_info.diff
# make sure diff is not too big:
if len(diff) > cls.max_diff_size_bytes:
                messages.append(
                    "======> WARNING! Git diff too large to store "
                    "({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
auxiliary_git_diff = diff
diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
'# full git diff available in Artifacts/auxiliary_git_diff\n' \
'# Clear the section before enqueueing Task!\n'
else:
diff = ''
# if this is not jupyter, get the requirements.txt
requirements = ''
conda_requirements = ''
# create requirements if backend supports requirements
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if not jupyter_filepath and Session.check_min_api_version('2.2'):
script_requirements = ScriptRequirements(
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
if create_requirements:
requirements, conda_requirements = script_requirements.get_requirements()
else:
script_requirements = None
script_info = dict(
repository=remove_user_pass_from_url(repo_info.url),
branch=repo_info.branch,
version_num=repo_info.commit,
entry_point=entry_point,
working_dir=working_dir,
diff=diff,
requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
repo_root=repo_root,
jupyter_filepath=jupyter_filepath,
)
if repo_info.modified:
messages.append(
"======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY {} <======".format(
script_info.get("repository", "")
)
)
if not any(script_info.values()):
script_info = None
return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
script_requirements)
@classmethod
def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None):
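        """Public entry point: same as _get_script_info() (defaulting to sys.argv[0]),
        but returns an empty ScriptInfoResult instead of raising on failure."""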
try:
if not filepaths:
filepaths = [sys.argv[0], ]
return cls._get_script_info(
filepaths=filepaths, check_uncommitted=check_uncommitted,
create_requirements=create_requirements, log=log)
except Exception as ex:
if log:
log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult(), None
@classmethod
def is_running_from_module(cls):
# noinspection PyBroadException
try:
return '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']
except Exception:
return False
@classmethod
def detect_running_module(cls, script_dict):
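        """If the script was launched as `python -m <package> ...`, rewrite the entry
        point to the equivalent `-m <package> <argv>` form, with file arguments made
        relative to the working directory inside the repository."""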
# noinspection PyBroadException
try:
# If this is jupyter, do not try to detect the running module, we know what we have.
if script_dict.get('jupyter_filepath'):
return script_dict
if cls.is_running_from_module():
argvs = ''
git_root = os.path.abspath(script_dict['repo_root']) if script_dict['repo_root'] else None
for a in sys.argv[1:]:
if git_root and os.path.exists(a):
# check if common to project:
a_abs = os.path.abspath(a)
if os.path.commonpath([a_abs, git_root]) == git_root:
# adjust path relative to working dir inside git repo
a = ' ' + os.path.relpath(a_abs, os.path.join(git_root, script_dict['working_dir']))
argvs += ' {}'.format(a)
# update the script entry point to match the real argv and module call
script_dict['entry_point'] = '-m {}{}'.format(
vars(sys.modules['__main__'])['__package__'], (' ' + argvs) if argvs else '')
except Exception:
pass
return script_dict
@classmethod
def close(cls):
_JupyterObserver.close()
@attr.s
class ScriptInfoResult(object):
script = attr.ib(default=None)
warning_messages = attr.ib(factory=list)
auxiliary_git_diff = attr.ib(default=None)
class _JupyterHistoryLogger(object):
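    """Hook IPython's run-cell events and keep an ordered copy of the executed cells'
    code, so the notebook history can be reconstructed without a local .ipynb file."""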
_reg_replace_ipython = r'\n([ \t]*)get_ipython\(\)'
_reg_replace_magic = r'\n([ \t]*)%'
_reg_replace_bang = r'\n([ \t]*)!'
def __init__(self):
self._exception_raised = False
self._cells_code = {}
self._counter = 0
self._ip = None
self._current_cell = None
# noinspection PyBroadException
try:
import re
self._replace_ipython_pattern = re.compile(self._reg_replace_ipython)
self._replace_magic_pattern = re.compile(self._reg_replace_magic)
self._replace_bang_pattern = re.compile(self._reg_replace_bang)
except Exception:
self._replace_ipython_pattern = None
self._replace_magic_pattern = None
self._replace_bang_pattern = None
def hook(self, ip=None):
if not ip:
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from IPython import get_ipython
except Exception:
return
self._ip = get_ipython()
else:
self._ip = ip
# noinspection PyBroadException
try:
# if this is colab, the callbacks do not contain the raw_cell content, so we have to patch it
if 'google.colab' in self._ip.extension_manager.loaded:
self._ip._org_run_cell = self._ip.run_cell
self._ip.run_cell = partial(self._patched_run_cell, self._ip)
except Exception:
pass
# start with the current history
self._initialize_history()
self._ip.events.register('post_run_cell', self._post_cell_callback)
self._ip.events.register('pre_run_cell', self._pre_cell_callback)
self._ip.set_custom_exc((Exception,), self._exception_callback)
def _patched_run_cell(self, shell, *args, **kwargs):
# noinspection PyBroadException
try:
raw_cell = kwargs.get('raw_cell') or args[0]
self._current_cell = raw_cell
except Exception:
pass
# noinspection PyProtectedMember
return shell._org_run_cell(*args, **kwargs)
def history(self, filename):
with open(filename, 'wt') as f:
for k, v in sorted(self._cells_code.items(), key=lambda p: p[0]):
f.write(v)
def history_to_str(self):
        # return a pair: (history as str, current cell if we are still in cell execution, otherwise None)
return '\n'.join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell
# noinspection PyUnusedLocal
def _exception_callback(self, shell, etype, value, tb, tb_offset=None):
self._exception_raised = True
return shell.showtraceback()
def _pre_cell_callback(self, *args, **_):
# noinspection PyBroadException
try:
if args:
self._current_cell = args[0].raw_cell
# we might have this value from somewhere else
if self._current_cell:
self._current_cell = self._conform_code(self._current_cell, replace_magic_bang=True)
except Exception:
pass
def _post_cell_callback(self, *_, **__):
# noinspection PyBroadException
try:
self._current_cell = None
if self._exception_raised:
# do nothing
self._exception_raised = False
return
self._exception_raised = False
# add the cell history
# noinspection PyBroadException
try:
cell_code = '\n' + self._ip.history_manager.input_hist_parsed[-1]
except Exception:
return
# fix magic / bang in code
cell_code = self._conform_code(cell_code)
self._cells_code[self._counter] = cell_code
self._counter += 1
except Exception:
pass
def _initialize_history(self):
# only once
if -1 in self._cells_code:
return
# noinspection PyBroadException
try:
cell_code = '\n' + '\n'.join(self._ip.history_manager.input_hist_parsed[:-1])
except Exception:
return
cell_code = self._conform_code(cell_code)
self._cells_code[-1] = cell_code
def _conform_code(self, cell_code, replace_magic_bang=False):
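        """Comment out IPython-specific lines (get_ipython() calls and, optionally,
        magic/bang commands) so the collected cell code stays plain python."""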
# fix magic / bang in code
if self._replace_ipython_pattern:
cell_code = self._replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', cell_code)
if replace_magic_bang and self._replace_magic_pattern and self._replace_bang_pattern:
cell_code = self._replace_magic_pattern.sub(r'\n# \g<1>%', cell_code)
cell_code = self._replace_bang_pattern.sub(r'\n# \g<1>!', cell_code)
return cell_code