Fix Google CoLab code/package detection

This commit is contained in:
allegroai 2020-06-13 22:12:28 +03:00
parent aa61fa3f06
commit 8a5f6b7d02

View File

@ -1,10 +1,10 @@
import os import os
import sys import sys
from copy import copy from copy import copy
from functools import partial
from tempfile import mkstemp from tempfile import mkstemp
import attr import attr
import collections
import logging import logging
import json import json
from furl import furl from furl import furl
@ -30,6 +30,7 @@ class ScriptRequirements(object):
self._root_folder = root_folder self._root_folder = root_folder
def get_requirements(self, entry_point_filename=None): def get_requirements(self, entry_point_filename=None):
# noinspection PyBroadException
try: try:
from ....utilities.pigar.reqs import get_installed_pkgs_detail from ....utilities.pigar.reqs import get_installed_pkgs_detail
from ....utilities.pigar.__main__ import GenerateReqs from ....utilities.pigar.__main__ import GenerateReqs
@ -48,18 +49,21 @@ class ScriptRequirements(object):
# hack: forcefully insert storage modules if we have them # hack: forcefully insert storage modules if we have them
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
import boto3 import boto3
modules.add('boto3', 'trains.storage', 0) modules.add('boto3', 'trains.storage', 0)
except Exception: except Exception:
pass pass
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
from google.cloud import storage from google.cloud import storage
modules.add('google_cloud_storage', 'trains.storage', 0) modules.add('google_cloud_storage', 'trains.storage', 0)
except Exception: except Exception:
pass pass
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements,PyUnresolvedReferences
from azure.storage.blob import ContentSettings from azure.storage.blob import ContentSettings
modules.add('azure_storage_blob', 'trains.storage', 0) modules.add('azure_storage_blob', 'trains.storage', 0)
except Exception: except Exception:
@ -77,7 +81,9 @@ class ScriptRequirements(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# see if this version of torch support tensorboard # see if this version of torch support tensorboard
# noinspection PyPackageRequirements,PyUnresolvedReferences
import torch.utils.tensorboard import torch.utils.tensorboard
# noinspection PyPackageRequirements,PyUnresolvedReferences
import tensorboard import tensorboard
modules.add('tensorboard', 'torch', 0) modules.add('tensorboard', 'torch', 0)
except Exception: except Exception:
@ -91,6 +97,7 @@ class ScriptRequirements(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from ..task import Task from ..task import Task
# noinspection PyProtectedMember
for package, version in Task._force_requirements.items(): for package, version in Task._force_requirements.items():
modules.add(package, 'trains', 0) modules.add(package, 'trains', 0)
except Exception: except Exception:
@ -101,6 +108,7 @@ class ScriptRequirements(object):
@staticmethod @staticmethod
def create_requirements_txt(reqs, local_pks=None): def create_requirements_txt(reqs, local_pks=None):
# write requirements.txt # write requirements.txt
# noinspection PyBroadException
try: try:
conda_requirements = '' conda_requirements = ''
conda_prefix = os.environ.get('CONDA_PREFIX') conda_prefix = os.environ.get('CONDA_PREFIX')
@ -120,15 +128,16 @@ class ScriptRequirements(object):
if name == 'pytorch': if name == 'pytorch':
name = 'torch' name = 'torch'
k, v = reqs_lower.get(name, (None, None)) k, v = reqs_lower.get(name, (None, None))
if k: if k and v is not None:
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version) conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
except: except Exception:
conda_requirements = '' conda_requirements = ''
# add forced requirements: # add forced requirements:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from ..task import Task from ..task import Task
# noinspection PyProtectedMember
forced_packages = copy(Task._force_requirements) forced_packages = copy(Task._force_requirements)
except Exception: except Exception:
forced_packages = {} forced_packages = {}
@ -198,15 +207,20 @@ class _JupyterObserver(object):
_sync_event = Event() _sync_event = Event()
_sample_frequency = 30. _sample_frequency = 30.
_first_sample_frequency = 3. _first_sample_frequency = 3.
_jupyter_history_logger = None
@classmethod @classmethod
def observer(cls, jupyter_notebook_filename): def observer(cls, jupyter_notebook_filename, log_history):
if cls._thread is not None: if cls._thread is not None:
# order of signaling is important! # order of signaling is important!
cls._exit_event.set() cls._exit_event.set()
cls._sync_event.set() cls._sync_event.set()
cls._thread.join() cls._thread.join()
if log_history and cls._jupyter_history_logger is None:
cls._jupyter_history_logger = _JupyterHistoryLogger()
cls._jupyter_history_logger.hook()
cls._sync_event.clear() cls._sync_event.clear()
cls._exit_event.clear() cls._exit_event.clear()
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
@ -214,7 +228,7 @@ class _JupyterObserver(object):
cls._thread.start() cls._thread.start()
@classmethod @classmethod
def signal_sync(cls, *_): def signal_sync(cls, *_, **__):
cls._sync_event.set() cls._sync_event.set()
@classmethod @classmethod
@ -233,6 +247,7 @@ class _JupyterObserver(object):
# load jupyter notebook package # load jupyter notebook package
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter from nbconvert.exporters.script import ScriptExporter
_script_exporter = ScriptExporter() _script_exporter = ScriptExporter()
except Exception: except Exception:
@ -249,6 +264,7 @@ class _JupyterObserver(object):
# load IPython # load IPython
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements
from IPython import get_ipython from IPython import get_ipython
except Exception: except Exception:
# should not happen # should not happen
@ -266,16 +282,18 @@ class _JupyterObserver(object):
counter = 0 counter = 0
prev_script_hash = None prev_script_hash = None
# noinspection PyBroadException
try: try:
from ....version import __version__ from ....version import __version__
our_module = cls.__module__.split('.')[0], __version__ our_module = cls.__module__.split('.')[0], __version__
except: except Exception:
our_module = None our_module = None
# noinspection PyBroadException
try: try:
import re import re
replace_ipython_pattern = re.compile('\\n([ \\t]*)get_ipython\(\)') replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)')
except: except Exception:
replace_ipython_pattern = None replace_ipython_pattern = None
# main observer loop, check if we need to exit # main observer loop, check if we need to exit
@ -292,6 +310,9 @@ class _JupyterObserver(object):
if not task: if not task:
continue continue
script_code = None
fmodules = None
current_cell = None
# if we have a local file: # if we have a local file:
if notebook: if notebook:
if not notebook.exists(): if not notebook.exists():
@ -302,35 +323,67 @@ class _JupyterObserver(object):
last_update_ts = notebook.stat().st_mtime last_update_ts = notebook.stat().st_mtime
else: else:
# serialize notebook to a temp file # serialize notebook to a temp file
if cls._jupyter_history_logger:
script_code, current_cell = cls._jupyter_history_logger.history_to_str()
else:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
get_ipython().run_line_magic('notebook', local_jupyter_filename) # noinspection PyBroadException
except Exception as ex: try:
os.unlink(local_jupyter_filename)
except Exception:
pass
get_ipython().run_line_magic('history', '-t -f {}'.format(local_jupyter_filename))
with open(local_jupyter_filename, 'r') as f:
script_code = f.read()
# load the modules
from ....utilities.pigar.modules import ImportedModules
fmodules = ImportedModules()
for nm in set([str(m).split('.')[0] for m in sys.modules]):
fmodules.add(nm, 'notebook', 0)
except Exception:
continue continue
# get notebook python script # get notebook python script
script_code, resources = _script_exporter.from_filename(local_jupyter_filename) if script_code is None:
current_script_hash = hash(script_code) script_code, _ = _script_exporter.from_filename(local_jupyter_filename)
current_script_hash = hash(script_code + (current_cell or ''))
if prev_script_hash and prev_script_hash == current_script_hash: if prev_script_hash and prev_script_hash == current_script_hash:
continue continue
# remove ipython direct access from the script code # remove ipython direct access from the script code
# we will not be able to run them anyhow # we will not be able to run them anyhow
if replace_ipython_pattern: if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub('\n# \g<1>get_ipython()', script_code) script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
requirements_txt = '' requirements_txt = ''
conda_requirements = '' conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar) # parse jupyter python script and prepare pip requirements (pigar)
# if backend supports requirements # if backend supports requirements
if file_import_modules and Session.check_min_api_version('2.2'): if file_import_modules and Session.check_min_api_version('2.2'):
fmodules, _ = file_import_modules(notebook.parts[-1], script_code) if fmodules is None:
fmodules, _ = file_import_modules(
notebook.parts[-1] if notebook else 'notebook', script_code)
if current_cell:
cell_fmodules, _ = file_import_modules(
notebook.parts[-1] if notebook else 'notebook', current_cell)
# noinspection PyBroadException
try:
fmodules |= cell_fmodules
except Exception:
pass
# add current cell to the script
if current_cell:
script_code += '\n' + current_cell
fmodules = ScriptRequirements.add_trains_used_packages(fmodules) fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
# noinspection PyUnboundLocalVariable
installed_pkgs = get_installed_pkgs_detail() installed_pkgs = get_installed_pkgs_detail()
# make sure we are in installed packages # make sure we are in installed packages
if our_module and (our_module[0] not in installed_pkgs): if our_module and (our_module[0] not in installed_pkgs):
installed_pkgs[our_module[0]] = our_module installed_pkgs[our_module[0]] = our_module
# noinspection PyUnboundLocalVariable
reqs = ReqsModules() reqs = ReqsModules()
for name in fmodules: for name in fmodules:
if name in installed_pkgs: if name in installed_pkgs:
@ -343,8 +396,10 @@ class _JupyterObserver(object):
data_script = task.data.script data_script = task.data.script
data_script.diff = script_code data_script.diff = script_code
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements} data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
# noinspection PyProtectedMember
task._update_script(script=data_script) task._update_script(script=data_script)
# update requirements # update requirements
# noinspection PyProtectedMember
task._update_requirements(requirements=requirements_txt) task._update_requirements(requirements=requirements_txt)
except Exception: except Exception:
pass pass
@ -356,14 +411,17 @@ class ScriptInfo(object):
""" Script info detection plugins, in order of priority """ """ Script info detection plugins, in order of priority """
@classmethod @classmethod
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename): def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if 'IPython' in sys.modules: if 'IPython' in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython from IPython import get_ipython
if get_ipython(): if get_ipython():
_JupyterObserver.observer(jupyter_notebook_filename) _JupyterObserver.observer(jupyter_notebook_filename, log_history)
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
if log_history:
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
except Exception: except Exception:
pass pass
@ -377,22 +435,26 @@ class ScriptInfo(object):
# we can safely assume that we can import the notebook package here # we can safely assume that we can import the notebook package here
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements
from notebook.notebookapp import list_running_servers from notebook.notebookapp import list_running_servers
import requests import requests
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '') current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
# noinspection PyBroadException
try: try:
server_info = next(list_running_servers()) server_info = next(list_running_servers())
except Exception: except Exception:
# on some jupyter notebook versions this function can crash on parsing the json file, # on some jupyter notebook versions this function can crash on parsing the json file,
# we will parse it manually here # we will parse it manually here
# noinspection PyPackageRequirements
import ipykernel import ipykernel
from glob import glob from glob import glob
import json import json
for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')): for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')):
# noinspection PyBroadException
try: try:
with open(f, 'r') as json_data: with open(f, 'r') as json_data:
server_info = json.load(json_data) server_info = json.load(json_data)
except: except Exception:
server_info = None server_info = None
if server_info: if server_info:
break break
@ -403,6 +465,7 @@ class ScriptInfo(object):
except requests.exceptions.SSLError: except requests.exceptions.SSLError:
# disable SSL check warning # disable SSL check warning
from urllib3.exceptions import InsecureRequestWarning from urllib3.exceptions import InsecureRequestWarning
# noinspection PyUnresolvedReferences
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
# fire request # fire request
r = requests.get( r = requests.get(
@ -428,6 +491,7 @@ class ScriptInfo(object):
# check if this is google.colab, then there is no local file # check if this is google.colab, then there is no local file
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyPackageRequirements
from IPython import get_ipython from IPython import get_ipython
if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded: if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
is_google_colab = True is_google_colab = True
@ -435,7 +499,10 @@ class ScriptInfo(object):
pass pass
if is_google_colab: if is_google_colab:
script_entry_point = notebook_name script_entry_point = str(notebook_name or 'notebook').replace(
'>', '_').replace('<', '_').replace('.ipynb', '.py')
if not script_entry_point.lower().endswith('.py'):
script_entry_point += '.py'
local_ipynb_file = None local_ipynb_file = None
else: else:
# always slash, because this is from uri (so never backslash not even oon windows) # always slash, because this is from uri (so never backslash not even oon windows)
@ -457,7 +524,7 @@ class ScriptInfo(object):
# install the post store hook, # install the post store hook,
# notice that if we do not have a local file we serialize/write every time the entire notebook # notice that if we do not have a local file we serialize/write every time the entire notebook
cls._jupyter_install_post_store_hook(local_ipynb_file) cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab)
return script_entry_point return script_entry_point
except Exception: except Exception:
@ -641,3 +708,135 @@ class ScriptInfo(object):
class ScriptInfoResult(object): class ScriptInfoResult(object):
script = attr.ib(default=None) script = attr.ib(default=None)
warning_messages = attr.ib(factory=list) warning_messages = attr.ib(factory=list)
class _JupyterHistoryLogger(object):
_reg_replace_ipython = r'\n([ \t]*)get_ipython\(\)'
_reg_replace_magic = r'\n([ \t]*)%'
_reg_replace_bang = r'\n([ \t]*)!'
def __init__(self):
self._exception_raised = False
self._cells_code = {}
self._counter = 0
self._ip = None
self._current_cell = None
# noinspection PyBroadException
try:
import re
self._replace_ipython_pattern = re.compile(self._reg_replace_ipython)
self._replace_magic_pattern = re.compile(self._reg_replace_magic)
self._replace_bang_pattern = re.compile(self._reg_replace_bang)
except Exception:
self._replace_ipython_pattern = None
self._replace_magic_pattern = None
self._replace_bang_pattern = None
def hook(self, ip=None):
if not ip:
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from IPython import get_ipython
except Exception:
return
self._ip = get_ipython()
else:
self._ip = ip
# noinspection PyBroadException
try:
# if this is colab, the callbacks do not contain the raw_cell content, so we have to patch it
if 'google.colab' in self._ip.extension_manager.loaded:
self._ip._org_run_cell = self._ip.run_cell
self._ip.run_cell = partial(self._patched_run_cell, self._ip)
except Exception as ex:
pass
# start with the current history
self._initialize_history()
self._ip.events.register('post_run_cell', self._post_cell_callback)
self._ip.events.register('pre_run_cell', self._pre_cell_callback)
self._ip.set_custom_exc((Exception,), self._exception_callback)
def _patched_run_cell(self, shell, *args, **kwargs):
# noinspection PyBroadException
try:
raw_cell = kwargs.get('raw_cell') or args[0]
self._current_cell = raw_cell
except Exception:
pass
# noinspection PyProtectedMember
return shell._org_run_cell(*args, **kwargs)
def history(self, filename):
with open(filename, 'wt') as f:
for k, v in sorted(self._cells_code.items(), key=lambda p: p[0]):
f.write(v)
def history_to_str(self):
# return a pair: (history as str, current cell if we are in still in cell execution otherwise None)
return '\n'.join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell
# noinspection PyUnusedLocal
def _exception_callback(self, shell, etype, value, tb, tb_offset=None):
self._exception_raised = True
return shell.showtraceback()
def _pre_cell_callback(self, *args, **_):
# noinspection PyBroadException
try:
if args:
self._current_cell = args[0].raw_cell
# we might have this value from somewhere else
if self._current_cell:
self._current_cell = self._conform_code(self._current_cell, replace_magic_bang=True)
except Exception:
pass
def _post_cell_callback(self, *_, **__):
# noinspection PyBroadException
try:
self._current_cell = None
if self._exception_raised:
# do nothing
self._exception_raised = False
return
self._exception_raised = False
# add the cell history
# noinspection PyBroadException
try:
cell_code = '\n' + self._ip.history_manager.input_hist_parsed[-1]
except Exception:
return
# fix magic / bang in code
cell_code = self._conform_code(cell_code)
self._cells_code[self._counter] = cell_code
self._counter += 1
except Exception:
pass
def _initialize_history(self):
# only once
if -1 in self._cells_code:
return
# noinspection PyBroadException
try:
cell_code = '\n' + '\n'.join(self._ip.history_manager.input_hist_parsed[:-1])
except Exception:
return
cell_code = self._conform_code(cell_code)
self._cells_code[-1] = cell_code
def _conform_code(self, cell_code, replace_magic_bang=False):
# fix magic / bang in code
if self._replace_ipython_pattern:
cell_code = self._replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', cell_code)
if replace_magic_bang and self._replace_magic_pattern and self._replace_bang_pattern:
cell_code = self._replace_magic_pattern.sub(r'\n# \g<1>%', cell_code)
cell_code = self._replace_bang_pattern.sub(r'\n# \g<1>!', cell_code)
return cell_code