Improve CoLab integration (store entire colab, not history)

This commit is contained in:
allegroai 2022-10-30 19:25:15 +02:00
parent 89b675c267
commit 16fb59c33f

View File

@ -3,7 +3,7 @@ import sys
from copy import copy from copy import copy
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from tempfile import mkstemp, gettempdir from tempfile import gettempdir, mkdtemp
import attr import attr
import logging import logging
@ -273,7 +273,7 @@ class _JupyterObserver(object):
return get_logger("Repository Detection") return get_logger("Repository Detection")
@classmethod @classmethod
def observer(cls, jupyter_notebook_filename, log_history): def observer(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
if cls._thread is not None: if cls._thread is not None:
# order of signaling is important! # order of signaling is important!
cls._exit_event.set() cls._exit_event.set()
@ -286,7 +286,7 @@ class _JupyterObserver(object):
cls._sync_event.clear() cls._sync_event.clear()
cls._exit_event.clear() cls._exit_event.clear()
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, notebook_name))
cls._thread.daemon = True cls._thread.daemon = True
cls._thread.start() cls._thread.start()
@ -304,15 +304,25 @@ class _JupyterObserver(object):
cls._thread = None cls._thread = None
@classmethod @classmethod
def _daemon(cls, jupyter_notebook_filename): def _daemon(cls, jupyter_notebook_filename, notebook_name=None):
from clearml import Task from clearml import Task
# load jupyter notebook package # load jupyter notebook package
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters import PythonExporter
_script_exporter = PythonExporter()
except Exception:
_script_exporter = None
if _script_exporter is None:
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter from nbconvert.exporters.script import ScriptExporter
_script_exporter = ScriptExporter() _script_exporter = ScriptExporter()
except Exception as ex: except Exception as ex:
cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex)) cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex))
_script_exporter = None _script_exporter = None
@ -341,9 +351,15 @@ class _JupyterObserver(object):
local_jupyter_filename = jupyter_notebook_filename local_jupyter_filename = jupyter_notebook_filename
else: else:
notebook = None notebook = None
fd, local_jupyter_filename = mkstemp(suffix='.ipynb') folder = mkdtemp(suffix='.notebook')
os.close(fd) if notebook_name.endswith(".py"):
notebook_name = notebook_name.replace(".py", ".ipynb")
if not notebook_name.endswith(".ipynb"):
notebook_name += ".ipynb"
local_jupyter_filename = Path(folder) / notebook_name
last_update_ts = None last_update_ts = None
last_colab_hash = None
counter = 0 counter = 0
prev_script_hash = None prev_script_hash = None
@ -357,7 +373,7 @@ class _JupyterObserver(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import re import re
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)') replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)')
replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(') replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(')
except Exception: except Exception:
replace_ipython_pattern = None replace_ipython_pattern = None
@ -388,6 +404,20 @@ class _JupyterObserver(object):
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
continue continue
last_update_ts = notebook.stat().st_mtime last_update_ts = notebook.stat().st_mtime
elif notebook_name:
# this is a colab, let's try to get the notebook
# noinspection PyProtectedMember
colab_name, colab_notebook = ScriptInfo._get_colab_notebook()
if colab_notebook:
current_colab_hash = hash(colab_notebook)
if current_colab_hash == last_colab_hash:
continue
last_colab_hash = current_colab_hash
with open(local_jupyter_filename.as_posix(), "wt") as f:
f.write(colab_notebook)
else:
# something went wrong we will try again later
continue
else: else:
# serialize notebook to a temp file # serialize notebook to a temp file
if cls._jupyter_history_logger: if cls._jupyter_history_logger:
@ -449,13 +479,6 @@ class _JupyterObserver(object):
if prev_script_hash and prev_script_hash == current_script_hash: if prev_script_hash and prev_script_hash == current_script_hash:
continue continue
# remove ipython direct access from the script code
# we will not be able to run them anyhow
if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
if replace_ipython_display_pattern:
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
requirements_txt = '' requirements_txt = ''
conda_requirements = '' conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar) # parse jupyter python script and prepare pip requirements (pigar)
@ -490,6 +513,14 @@ class _JupyterObserver(object):
reqs.add(pkg_name, version, fmodules[name]) reqs.add(pkg_name, version, fmodules[name])
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs) requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
# remove ipython direct access from the script code
# we will not be able to run them anyhow
# probably should be better dealt with, because multi line will break it
if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
if replace_ipython_display_pattern:
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
# update script # update script
prev_script_hash = current_script_hash prev_script_hash = current_script_hash
data_script = task.data.script data_script = task.data.script
@ -515,14 +546,15 @@ class ScriptInfo(object):
return get_logger("Repository Detection") return get_logger("Repository Detection")
@classmethod @classmethod
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False): def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if 'IPython' in sys.modules: if 'IPython' in sys.modules:
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
from IPython import get_ipython from IPython import get_ipython
if get_ipython(): if get_ipython():
_JupyterObserver.observer(jupyter_notebook_filename, log_history) _JupyterObserver.observer(
jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history)
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
if log_history: if log_history:
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync) get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
@ -662,6 +694,8 @@ class ScriptInfo(object):
break break
is_google_colab = False is_google_colab = False
log_history = False
colab_name = None
# check if this is google.colab, then there is no local file # check if this is google.colab, then there is no local file
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -673,8 +707,15 @@ class ScriptInfo(object):
pass pass
if is_google_colab: if is_google_colab:
# check if we can get the notebook
colab_name, colab_notebook = cls._get_colab_notebook()
if colab_name is not None:
notebook_name = colab_name
log_history = False
script_entry_point = str(notebook_name or 'notebook').replace( script_entry_point = str(notebook_name or 'notebook').replace(
'>', '_').replace('<', '_').replace('.ipynb', '.py') '>', '_').replace('<', '_').replace('.ipynb', '.py')
if not script_entry_point.lower().endswith('.py'): if not script_entry_point.lower().endswith('.py'):
script_entry_point += '.py' script_entry_point += '.py'
local_ipynb_file = None local_ipynb_file = None
@ -724,12 +765,29 @@ class ScriptInfo(object):
# install the post store hook, # install the post store hook,
# notice that if we do not have a local file we serialize/write every time the entire notebook # notice that if we do not have a local file we serialize/write every time the entire notebook
cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab) cls._jupyter_install_post_store_hook(local_ipynb_file, notebook_name=colab_name, log_history=log_history)
return script_entry_point return script_entry_point
except Exception: except Exception:
return None return None
@classmethod
def _get_colab_notebook(cls, timeout=30):
# returns tuple (notebook name, raw string notebook)
# None, None if fails
try:
from google.colab import _message
notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
if not notebook_name.endswith(".ipynb"):
notebook_name += ".ipynb"
# encoding to json
return notebook_name, json.dumps(notebook)
except: # noqa
return None, None
@classmethod @classmethod
def _get_entry_point(cls, repo_root, script_path): def _get_entry_point(cls, repo_root, script_path):
repo_root = Path(repo_root).absolute() repo_root = Path(repo_root).absolute()