From 16fb59c33f24b6f1bf0d7c08867e258d596b24c8 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 30 Oct 2022 19:25:15 +0200 Subject: [PATCH] Improve CoLab integration (store entire colab, not history) --- .../backend_interface/task/repo/scriptinfo.py | 98 +++++++++++++++---- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/clearml/backend_interface/task/repo/scriptinfo.py b/clearml/backend_interface/task/repo/scriptinfo.py index e6c0bd48..30c1004a 100644 --- a/clearml/backend_interface/task/repo/scriptinfo.py +++ b/clearml/backend_interface/task/repo/scriptinfo.py @@ -3,7 +3,7 @@ import sys from copy import copy from datetime import datetime from functools import partial -from tempfile import mkstemp, gettempdir +from tempfile import gettempdir, mkdtemp import attr import logging @@ -273,7 +273,7 @@ class _JupyterObserver(object): return get_logger("Repository Detection") @classmethod - def observer(cls, jupyter_notebook_filename, log_history): + def observer(cls, jupyter_notebook_filename, notebook_name=None, log_history=False): if cls._thread is not None: # order of signaling is important! cls._exit_event.set() @@ -286,7 +286,7 @@ class _JupyterObserver(object): cls._sync_event.clear() cls._exit_event.clear() - cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) + cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, notebook_name)) cls._thread.daemon = True cls._thread.start() @@ -304,15 +304,25 @@ class _JupyterObserver(object): cls._thread = None @classmethod - def _daemon(cls, jupyter_notebook_filename): + def _daemon(cls, jupyter_notebook_filename, notebook_name=None): from clearml import Task # load jupyter notebook package # noinspection PyBroadException try: - # noinspection PyPackageRequirements - from nbconvert.exporters.script import ScriptExporter - _script_exporter = ScriptExporter() + # noinspection PyBroadException + try: + # noinspection PyPackageRequirements + from nbconvert.exporters import PythonExporter + _script_exporter = PythonExporter() + except Exception: + _script_exporter = None + + if _script_exporter is None: + # noinspection PyPackageRequirements + from nbconvert.exporters.script import ScriptExporter + _script_exporter = ScriptExporter() + except Exception as ex: cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex)) _script_exporter = None @@ -341,9 +351,15 @@ class _JupyterObserver(object): local_jupyter_filename = jupyter_notebook_filename else: notebook = None - fd, local_jupyter_filename = mkstemp(suffix='.ipynb') - os.close(fd) + folder = mkdtemp(suffix='.notebook') + if notebook_name.endswith(".py"): + notebook_name = notebook_name.replace(".py", ".ipynb") + if not notebook_name.endswith(".ipynb"): + notebook_name += ".ipynb" + local_jupyter_filename = Path(folder) / notebook_name + last_update_ts = None + last_colab_hash = None counter = 0 prev_script_hash = None @@ -357,7 +373,7 @@ class _JupyterObserver(object): # noinspection PyBroadException try: import re - replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)') + replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)') replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(') except Exception: replace_ipython_pattern = None @@ -388,6 +404,20 @@ class _JupyterObserver(object): if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: continue last_update_ts = notebook.stat().st_mtime + elif notebook_name: + # this is a colab, let's try to get the notebook + # noinspection PyProtectedMember + colab_name, colab_notebook = ScriptInfo._get_colab_notebook() + if colab_notebook: + current_colab_hash = hash(colab_notebook) + if current_colab_hash == last_colab_hash: + continue + last_colab_hash = current_colab_hash + with open(local_jupyter_filename.as_posix(), "wt") as f: + f.write(colab_notebook) + else: + # something went wrong we will try again later + continue else: # serialize notebook to a temp file if cls._jupyter_history_logger: @@ -449,13 +479,6 @@ class _JupyterObserver(object): if prev_script_hash and prev_script_hash == current_script_hash: continue - # remove ipython direct access from the script code - # we will not be able to run them anyhow - if replace_ipython_pattern: - script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code) - if replace_ipython_display_pattern: - script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code) - requirements_txt = '' conda_requirements = '' # parse jupyter python script and prepare pip requirements (pigar) @@ -490,6 +513,14 @@ class _JupyterObserver(object): reqs.add(pkg_name, version, fmodules[name]) requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs) + # remove ipython direct access from the script code + # we will not be able to run them anyhow + # probably should be better dealt with, because multi line will break it + if replace_ipython_pattern: + script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code) + if replace_ipython_display_pattern: + script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code) + # update script prev_script_hash = current_script_hash data_script = task.data.script @@ -515,14 +546,15 @@ class ScriptInfo(object): return get_logger("Repository Detection") @classmethod - def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False): + def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False): # noinspection PyBroadException try: if 'IPython' in sys.modules: # noinspection PyPackageRequirements from IPython import get_ipython if get_ipython(): - _JupyterObserver.observer(jupyter_notebook_filename, log_history) + _JupyterObserver.observer( + jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history) get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) if log_history: get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync) @@ -662,6 +694,8 @@ class ScriptInfo(object): break is_google_colab = False + log_history = False + colab_name = None # check if this is google.colab, then there is no local file # noinspection PyBroadException try: @@ -673,8 +707,15 @@ class ScriptInfo(object): pass if is_google_colab: + # check if we can get the notebook + colab_name, colab_notebook = cls._get_colab_notebook() + if colab_name is not None: + notebook_name = colab_name + log_history = False + script_entry_point = str(notebook_name or 'notebook').replace( '>', '_').replace('<', '_').replace('.ipynb', '.py') + if not script_entry_point.lower().endswith('.py'): script_entry_point += '.py' local_ipynb_file = None @@ -724,12 +765,29 @@ class ScriptInfo(object): # install the post store hook, # notice that if we do not have a local file we serialize/write every time the entire notebook - cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab) + cls._jupyter_install_post_store_hook(local_ipynb_file, notebook_name=colab_name, log_history=log_history) return script_entry_point except Exception: return None + @classmethod + def _get_colab_notebook(cls, timeout=30): + # returns tuple (notebook name, raw string notebook) + # None, None if fails + try: + from google.colab import _message + + notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb'] + notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb") + if not notebook_name.endswith(".ipynb"): + notebook_name += ".ipynb" + + # encoding to json + return notebook_name, json.dumps(notebook) + except: # noqa + return None, None + @classmethod def _get_entry_point(cls, repo_root, script_path): repo_root = Path(repo_root).absolute()