mirror of
https://github.com/clearml/clearml
synced 2025-04-22 07:15:57 +00:00
Improve CoLab integration (store entire colab, not history)
This commit is contained in:
parent
89b675c267
commit
16fb59c33f
@ -3,7 +3,7 @@ import sys
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from tempfile import mkstemp, gettempdir
|
||||
from tempfile import gettempdir, mkdtemp
|
||||
|
||||
import attr
|
||||
import logging
|
||||
@ -273,7 +273,7 @@ class _JupyterObserver(object):
|
||||
return get_logger("Repository Detection")
|
||||
|
||||
@classmethod
|
||||
def observer(cls, jupyter_notebook_filename, log_history):
|
||||
def observer(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
|
||||
if cls._thread is not None:
|
||||
# order of signaling is important!
|
||||
cls._exit_event.set()
|
||||
@ -286,7 +286,7 @@ class _JupyterObserver(object):
|
||||
|
||||
cls._sync_event.clear()
|
||||
cls._exit_event.clear()
|
||||
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
|
||||
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, notebook_name))
|
||||
cls._thread.daemon = True
|
||||
cls._thread.start()
|
||||
|
||||
@ -304,15 +304,25 @@ class _JupyterObserver(object):
|
||||
cls._thread = None
|
||||
|
||||
@classmethod
|
||||
def _daemon(cls, jupyter_notebook_filename):
|
||||
def _daemon(cls, jupyter_notebook_filename, notebook_name=None):
|
||||
from clearml import Task
|
||||
|
||||
# load jupyter notebook package
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
from nbconvert.exporters.script import ScriptExporter
|
||||
_script_exporter = ScriptExporter()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
from nbconvert.exporters import PythonExporter
|
||||
_script_exporter = PythonExporter()
|
||||
except Exception:
|
||||
_script_exporter = None
|
||||
|
||||
if _script_exporter is None:
|
||||
# noinspection PyPackageRequirements
|
||||
from nbconvert.exporters.script import ScriptExporter
|
||||
_script_exporter = ScriptExporter()
|
||||
|
||||
except Exception as ex:
|
||||
cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex))
|
||||
_script_exporter = None
|
||||
@ -341,9 +351,15 @@ class _JupyterObserver(object):
|
||||
local_jupyter_filename = jupyter_notebook_filename
|
||||
else:
|
||||
notebook = None
|
||||
fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
|
||||
os.close(fd)
|
||||
folder = mkdtemp(suffix='.notebook')
|
||||
if notebook_name.endswith(".py"):
|
||||
notebook_name = notebook_name.replace(".py", ".ipynb")
|
||||
if not notebook_name.endswith(".ipynb"):
|
||||
notebook_name += ".ipynb"
|
||||
local_jupyter_filename = Path(folder) / notebook_name
|
||||
|
||||
last_update_ts = None
|
||||
last_colab_hash = None
|
||||
counter = 0
|
||||
prev_script_hash = None
|
||||
|
||||
@ -357,7 +373,7 @@ class _JupyterObserver(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import re
|
||||
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)')
|
||||
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)')
|
||||
replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(')
|
||||
except Exception:
|
||||
replace_ipython_pattern = None
|
||||
@ -388,6 +404,20 @@ class _JupyterObserver(object):
|
||||
if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
|
||||
continue
|
||||
last_update_ts = notebook.stat().st_mtime
|
||||
elif notebook_name:
|
||||
# this is a colab, let's try to get the notebook
|
||||
# noinspection PyProtectedMember
|
||||
colab_name, colab_notebook = ScriptInfo._get_colab_notebook()
|
||||
if colab_notebook:
|
||||
current_colab_hash = hash(colab_notebook)
|
||||
if current_colab_hash == last_colab_hash:
|
||||
continue
|
||||
last_colab_hash = current_colab_hash
|
||||
with open(local_jupyter_filename.as_posix(), "wt") as f:
|
||||
f.write(colab_notebook)
|
||||
else:
|
||||
# something went wrong we will try again later
|
||||
continue
|
||||
else:
|
||||
# serialize notebook to a temp file
|
||||
if cls._jupyter_history_logger:
|
||||
@ -449,13 +479,6 @@ class _JupyterObserver(object):
|
||||
if prev_script_hash and prev_script_hash == current_script_hash:
|
||||
continue
|
||||
|
||||
# remove ipython direct access from the script code
|
||||
# we will not be able to run them anyhow
|
||||
if replace_ipython_pattern:
|
||||
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
|
||||
if replace_ipython_display_pattern:
|
||||
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
|
||||
|
||||
requirements_txt = ''
|
||||
conda_requirements = ''
|
||||
# parse jupyter python script and prepare pip requirements (pigar)
|
||||
@ -490,6 +513,14 @@ class _JupyterObserver(object):
|
||||
reqs.add(pkg_name, version, fmodules[name])
|
||||
requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs)
|
||||
|
||||
# remove ipython direct access from the script code
|
||||
# we will not be able to run them anyhow
|
||||
# probably should be better dealt with, because multi line will break it
|
||||
if replace_ipython_pattern:
|
||||
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
|
||||
if replace_ipython_display_pattern:
|
||||
script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
|
||||
|
||||
# update script
|
||||
prev_script_hash = current_script_hash
|
||||
data_script = task.data.script
|
||||
@ -515,14 +546,15 @@ class ScriptInfo(object):
|
||||
return get_logger("Repository Detection")
|
||||
|
||||
@classmethod
|
||||
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False):
|
||||
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if 'IPython' in sys.modules:
|
||||
# noinspection PyPackageRequirements
|
||||
from IPython import get_ipython
|
||||
if get_ipython():
|
||||
_JupyterObserver.observer(jupyter_notebook_filename, log_history)
|
||||
_JupyterObserver.observer(
|
||||
jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history)
|
||||
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
|
||||
if log_history:
|
||||
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
|
||||
@ -662,6 +694,8 @@ class ScriptInfo(object):
|
||||
break
|
||||
|
||||
is_google_colab = False
|
||||
log_history = False
|
||||
colab_name = None
|
||||
# check if this is google.colab, then there is no local file
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -673,8 +707,15 @@ class ScriptInfo(object):
|
||||
pass
|
||||
|
||||
if is_google_colab:
|
||||
# check if we can get the notebook
|
||||
colab_name, colab_notebook = cls._get_colab_notebook()
|
||||
if colab_name is not None:
|
||||
notebook_name = colab_name
|
||||
log_history = False
|
||||
|
||||
script_entry_point = str(notebook_name or 'notebook').replace(
|
||||
'>', '_').replace('<', '_').replace('.ipynb', '.py')
|
||||
|
||||
if not script_entry_point.lower().endswith('.py'):
|
||||
script_entry_point += '.py'
|
||||
local_ipynb_file = None
|
||||
@ -724,12 +765,29 @@ class ScriptInfo(object):
|
||||
|
||||
# install the post store hook,
|
||||
# notice that if we do not have a local file we serialize/write every time the entire notebook
|
||||
cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab)
|
||||
cls._jupyter_install_post_store_hook(local_ipynb_file, notebook_name=colab_name, log_history=log_history)
|
||||
|
||||
return script_entry_point
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_colab_notebook(cls, timeout=30):
|
||||
# returns tuple (notebook name, raw string notebook)
|
||||
# None, None if fails
|
||||
try:
|
||||
from google.colab import _message
|
||||
|
||||
notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
|
||||
notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
|
||||
if not notebook_name.endswith(".ipynb"):
|
||||
notebook_name += ".ipynb"
|
||||
|
||||
# encoding to json
|
||||
return notebook_name, json.dumps(notebook)
|
||||
except: # noqa
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
def _get_entry_point(cls, repo_root, script_path):
|
||||
repo_root = Path(repo_root).absolute()
|
||||
|
Loading…
Reference in New Issue
Block a user