Improve CoLab integration (store entire colab, not history)

This commit is contained in:
allegroai 2022-10-30 19:25:15 +02:00
parent 89b675c267
commit 16fb59c33f

View File

@ -3,7 +3,7 @@ import sys
from copy import copy
from datetime import datetime
from functools import partial
from tempfile import mkstemp, gettempdir
from tempfile import gettempdir, mkdtemp
import attr
import logging
@ -273,7 +273,7 @@ class _JupyterObserver(object):
return get_logger("Repository Detection")
@classmethod
def observer(cls, jupyter_notebook_filename, log_history):
def observer(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
if cls._thread is not None:
# order of signaling is important!
cls._exit_event.set()
@ -286,7 +286,7 @@ class _JupyterObserver(object):
cls._sync_event.clear()
cls._exit_event.clear()
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, notebook_name))
cls._thread.daemon = True
cls._thread.start()
@ -304,15 +304,25 @@ class _JupyterObserver(object):
cls._thread = None
@classmethod
def _daemon(cls, jupyter_notebook_filename):
def _daemon(cls, jupyter_notebook_filename, notebook_name=None):
from clearml import Task
# load jupyter notebook package
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter
_script_exporter = ScriptExporter()
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters import PythonExporter
_script_exporter = PythonExporter()
except Exception:
_script_exporter = None
if _script_exporter is None:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter
_script_exporter = ScriptExporter()
except Exception as ex:
cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex))
_script_exporter = None
@ -341,9 +351,15 @@ class _JupyterObserver(object):
local_jupyter_filename = jupyter_notebook_filename
else:
notebook = None
fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
os.close(fd)
folder = mkdtemp(suffix='.notebook')
if notebook_name.endswith(".py"):
notebook_name = notebook_name.replace(".py", ".ipynb")
if not notebook_name.endswith(".ipynb"):
notebook_name += ".ipynb"
local_jupyter_filename = Path(folder) / notebook_name
last_update_ts = None
last_colab_hash = None
counter = 0
prev_script_hash = None
@ -357,7 +373,7 @@ class _JupyterObserver(object):
# noinspection PyBroadException
try:
import re
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)')
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)')
replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(')
except Exception:
replace_ipython_pattern = None
@ -388,6 +404,20 @@ class _JupyterObserver(object):
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
continue
last_update_ts = notebook.stat().st_mtime
elif notebook_name:
# this is a colab, let's try to get the notebook
# noinspection PyProtectedMember
colab_name, colab_notebook = ScriptInfo._get_colab_notebook()
if colab_notebook:
current_colab_hash = hash(colab_notebook)
if current_colab_hash == last_colab_hash:
continue
last_colab_hash = current_colab_hash
with open(local_jupyter_filename.as_posix(), "wt") as f:
f.write(colab_notebook)
else:
# something went wrong we will try again later
continue
else:
# serialize notebook to a temp file
if cls._jupyter_history_logger:
@ -449,13 +479,6 @@ class _JupyterObserver(object):
if prev_script_hash and prev_script_hash == current_script_hash:
continue
# remove ipython direct access from the script code
# we will not be able to run them anyhow
if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
if replace_ipython_display_pattern:
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
requirements_txt = ''
conda_requirements = ''
# parse jupyter python script and prepare pip requirements (pigar)
@ -490,6 +513,14 @@ class _JupyterObserver(object):
reqs.add(pkg_name, version, fmodules[name])
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
# remove ipython direct access from the script code
# we will not be able to run them anyhow
# probably should be better dealt with, because multi line will break it
if replace_ipython_pattern:
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
if replace_ipython_display_pattern:
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
# update script
prev_script_hash = current_script_hash
data_script = task.data.script
@ -515,14 +546,15 @@ class ScriptInfo(object):
return get_logger("Repository Detection")
@classmethod
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False):
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
# noinspection PyBroadException
try:
if 'IPython' in sys.modules:
# noinspection PyPackageRequirements
from IPython import get_ipython
if get_ipython():
_JupyterObserver.observer(jupyter_notebook_filename, log_history)
_JupyterObserver.observer(
jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history)
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
if log_history:
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
@ -662,6 +694,8 @@ class ScriptInfo(object):
break
is_google_colab = False
log_history = False
colab_name = None
# check if this is google.colab, then there is no local file
# noinspection PyBroadException
try:
@ -673,8 +707,15 @@ class ScriptInfo(object):
pass
if is_google_colab:
# check if we can get the notebook
colab_name, colab_notebook = cls._get_colab_notebook()
if colab_name is not None:
notebook_name = colab_name
log_history = False
script_entry_point = str(notebook_name or 'notebook').replace(
'>', '_').replace('<', '_').replace('.ipynb', '.py')
if not script_entry_point.lower().endswith('.py'):
script_entry_point += '.py'
local_ipynb_file = None
@ -724,12 +765,29 @@ class ScriptInfo(object):
# install the post store hook,
# notice that if we do not have a local file we serialize/write every time the entire notebook
cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab)
cls._jupyter_install_post_store_hook(local_ipynb_file, notebook_name=colab_name, log_history=log_history)
return script_entry_point
except Exception:
return None
@classmethod
def _get_colab_notebook(cls, timeout=30):
    """
    Request the currently open notebook from the Google Colab frontend.

    The notebook is fetched with a blocking 'get_ipynb' request and serialized to a JSON string.
    The notebook name is read from the notebook's `metadata.colab.name` entry; when that entry is
    missing or empty, "colab.ipynb" is used. The returned name always ends with ".ipynb".

    :param timeout: Maximum time, in seconds, passed as `timeout_sec` to the blocking request.
    :return: Tuple (notebook name, notebook serialized as a JSON string),
        or (None, None) if the notebook could not be retrieved
        (for example, `google.colab` cannot be imported or the request failed).
    """
    # noinspection PyBroadException
    try:
        # noinspection PyPackageRequirements
        from google.colab import _message

        notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']

        # any of the nested entries may exist but be None/empty, so do not rely on .get() defaults alone
        colab_metadata = (notebook.get("metadata") or {}).get("colab") or {}
        notebook_name = colab_metadata.get("name") or "colab.ipynb"
        if not notebook_name.endswith(".ipynb"):
            notebook_name += ".ipynb"

        # encoding to json
        return notebook_name, json.dumps(notebook)
    except Exception:
        # best effort: any failure means "no notebook available".
        # KeyboardInterrupt / SystemExit are deliberately not caught here, so a user interrupt
        # during the (potentially long) blocking request is not silently swallowed.
        return None, None
@classmethod
def _get_entry_point(cls, repo_root, script_path):
repo_root = Path(repo_root).absolute()