Add support for Azure notebook and google colab

This commit is contained in:
allegroai 2019-07-13 23:55:34 +03:00
parent 7d0bf4838e
commit 9c7e0747fb

View File

@ -1,5 +1,6 @@
import os
import sys
from tempfile import mkstemp
import attr
import collections
@ -119,20 +120,28 @@ class ScriptRequirements(object):
class _JupyterObserver(object):
_thread = None
_exit_event = Event()
_sync_event = Event()
_sample_frequency = 30.
_first_sample_frequency = 3.
@classmethod
def observer(cls, jupyter_notebook_filename):
if cls._thread is not None:
# order of signaling is important!
cls._exit_event.set()
cls._sync_event.set()
cls._thread.join()
cls._sync_event.clear()
cls._exit_event.clear()
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
cls._thread.daemon = True
cls._thread.start()
@classmethod
def signal_sync(cls, *_):
cls._sync_event.set()
@classmethod
def _daemon(cls, jupyter_notebook_filename):
from trains import Task
@ -153,28 +162,59 @@ class _JupyterObserver(object):
logger.setLevel(logging.WARNING)
except Exception:
file_import_modules = None
# main observer loop
notebook = Path(jupyter_notebook_filename)
# load IPython
# noinspection PyBroadException
try:
from IPython import get_ipython
except Exception:
# should not happen
get_ipython = None
# setup local notebook files
if jupyter_notebook_filename:
notebook = Path(jupyter_notebook_filename)
local_jupyter_filename = jupyter_notebook_filename
else:
notebook = None
fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
os.close(fd)
last_update_ts = None
counter = 0
prev_script_hash = None
# main observer loop
while True:
if cls._exit_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency):
# wait for timeout or sync event
cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
# check if we need to exit
if cls._exit_event.wait(timeout=0.):
return
cls._sync_event.clear()
counter += 1
# noinspection PyBroadException
try:
if not notebook.exists():
continue
# check if notebook changed
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
continue
last_update_ts = notebook.stat().st_mtime
# if there is no task connected, do nothing
task = Task.current_task()
if not task:
continue
# if we have a local file:
if notebook:
if not notebook.exists():
continue
# check if notebook changed
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
continue
last_update_ts = notebook.stat().st_mtime
else:
# serialize notebook to a temp file
# noinspection PyBroadException
try:
get_ipython().run_line_magic('notebook', local_jupyter_filename)
except Exception as ex:
continue
# get notebook python script
script_code, resources = _script_exporter.from_filename(jupyter_notebook_filename)
script_code, resources = _script_exporter.from_filename(local_jupyter_filename)
current_script_hash = hash(script_code)
if prev_script_hash and prev_script_hash == current_script_hash:
continue
@ -198,8 +238,7 @@ class _JupyterObserver(object):
data_script.requirements = {'pip': requirements_txt}
task._update_script(script=data_script)
# update requirements
if requirements_txt:
task._update_requirements(requirements=requirements_txt)
task._update_requirements(requirements=requirements_txt)
except Exception:
pass
@ -217,12 +256,15 @@ class ScriptInfo(object):
from IPython import get_ipython
if get_ipython():
_JupyterObserver.observer(jupyter_notebook_filename)
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
except Exception:
pass
@classmethod
def _get_jupyter_notebook_filename(cls):
if not sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
if not (sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or
sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
return None
# we can safely assume that we can import the notebook package here
@ -244,23 +286,45 @@ class ScriptInfo(object):
cur_notebook = n
break
notebook_path = cur_notebook['notebook']['path']
# always slash, because this is from uri (so never backslash not even oon windows)
entry_point_filename = notebook_path.split('/')[-1]
notebook_path = cur_notebook['notebook'].get('path', '')
notebook_name = cur_notebook['notebook'].get('name', '')
# now we should try to find the actual file
entry_point = (Path.cwd() / entry_point_filename).absolute()
if not entry_point.is_file():
entry_point = (Path.cwd() / notebook_path).absolute()
is_google_colab = False
# check if this is google.colab, then there is no local file
# noinspection PyBroadException
try:
from IPython import get_ipython
if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
is_google_colab = True
except Exception:
pass
# install the post store hook, so always have a synced file in the system
cls._jupyter_install_post_store_hook(entry_point.as_posix())
if is_google_colab:
script_entry_point = notebook_name
local_ipynb_file = None
else:
# always slash, because this is from uri (so never backslash not even oon windows)
entry_point_filename = notebook_path.split('/')[-1]
# now replace the .ipynb with .py
# we assume we will have that file available with the Jupyter notebook plugin
entry_point = entry_point.with_suffix('.py')
# now we should try to find the actual file
entry_point = (Path.cwd() / entry_point_filename).absolute()
if not entry_point.is_file():
entry_point = (Path.cwd() / notebook_path).absolute()
return entry_point.as_posix()
# get local ipynb for observer
local_ipynb_file = entry_point.as_posix()
# now replace the .ipynb with .py
# we assume we will have that file available with the Jupyter notebook plugin
entry_point = entry_point.with_suffix('.py')
script_entry_point = entry_point.as_posix()
# install the post store hook,
# notice that if we do not have a local file we serialize/write every time the entire notebook
cls._jupyter_install_post_store_hook(local_ipynb_file)
return script_entry_point
except Exception:
return None