mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Improve CoLab integration (store entire colab, not history)
This commit is contained in:
		
							parent
							
								
									89b675c267
								
							
						
					
					
						commit
						16fb59c33f
					
				| @ -3,7 +3,7 @@ import sys | ||||
| from copy import copy | ||||
| from datetime import datetime | ||||
| from functools import partial | ||||
| from tempfile import mkstemp, gettempdir | ||||
| from tempfile import gettempdir, mkdtemp | ||||
| 
 | ||||
| import attr | ||||
| import logging | ||||
| @ -273,7 +273,7 @@ class _JupyterObserver(object): | ||||
|         return get_logger("Repository Detection") | ||||
| 
 | ||||
|     @classmethod | ||||
|     def observer(cls, jupyter_notebook_filename, log_history): | ||||
|     def observer(cls, jupyter_notebook_filename, notebook_name=None, log_history=False): | ||||
|         if cls._thread is not None: | ||||
|             # order of signaling is important! | ||||
|             cls._exit_event.set() | ||||
| @ -286,7 +286,7 @@ class _JupyterObserver(object): | ||||
| 
 | ||||
|         cls._sync_event.clear() | ||||
|         cls._exit_event.clear() | ||||
|         cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) | ||||
|         cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, notebook_name)) | ||||
|         cls._thread.daemon = True | ||||
|         cls._thread.start() | ||||
| 
 | ||||
| @ -304,15 +304,25 @@ class _JupyterObserver(object): | ||||
|         cls._thread = None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _daemon(cls, jupyter_notebook_filename): | ||||
|     def _daemon(cls, jupyter_notebook_filename, notebook_name=None): | ||||
|         from clearml import Task | ||||
| 
 | ||||
|         # load jupyter notebook package | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             # noinspection PyPackageRequirements | ||||
|             from nbconvert.exporters.script import ScriptExporter | ||||
|             _script_exporter = ScriptExporter() | ||||
|             # noinspection PyBroadException | ||||
|             try: | ||||
|                 # noinspection PyPackageRequirements | ||||
|                 from nbconvert.exporters import PythonExporter | ||||
|                 _script_exporter = PythonExporter() | ||||
|             except Exception: | ||||
|                 _script_exporter = None | ||||
| 
 | ||||
|             if _script_exporter is None: | ||||
|                 # noinspection PyPackageRequirements | ||||
|                 from nbconvert.exporters.script import ScriptExporter | ||||
|                 _script_exporter = ScriptExporter() | ||||
| 
 | ||||
|         except Exception as ex: | ||||
|             cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex)) | ||||
|             _script_exporter = None | ||||
| @ -341,9 +351,15 @@ class _JupyterObserver(object): | ||||
|             local_jupyter_filename = jupyter_notebook_filename | ||||
|         else: | ||||
|             notebook = None | ||||
|             fd, local_jupyter_filename = mkstemp(suffix='.ipynb') | ||||
|             os.close(fd) | ||||
|             folder = mkdtemp(suffix='.notebook') | ||||
|             if notebook_name.endswith(".py"): | ||||
|                 notebook_name = notebook_name.replace(".py", ".ipynb") | ||||
|             if not notebook_name.endswith(".ipynb"): | ||||
|                 notebook_name += ".ipynb" | ||||
|             local_jupyter_filename = Path(folder) / notebook_name | ||||
| 
 | ||||
|         last_update_ts = None | ||||
|         last_colab_hash = None | ||||
|         counter = 0 | ||||
|         prev_script_hash = None | ||||
| 
 | ||||
| @ -357,7 +373,7 @@ class _JupyterObserver(object): | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             import re | ||||
|             replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)') | ||||
|             replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)') | ||||
|             replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(') | ||||
|         except Exception: | ||||
|             replace_ipython_pattern = None | ||||
| @ -388,6 +404,20 @@ class _JupyterObserver(object): | ||||
|                     if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: | ||||
|                         continue | ||||
|                     last_update_ts = notebook.stat().st_mtime | ||||
|                 elif notebook_name: | ||||
|                     # this is a colab, let's try to get the notebook | ||||
|                     # noinspection PyProtectedMember | ||||
|                     colab_name, colab_notebook = ScriptInfo._get_colab_notebook() | ||||
|                     if colab_notebook: | ||||
|                         current_colab_hash = hash(colab_notebook) | ||||
|                         if current_colab_hash == last_colab_hash: | ||||
|                             continue | ||||
|                         last_colab_hash = current_colab_hash | ||||
|                         with open(local_jupyter_filename.as_posix(), "wt") as f: | ||||
|                             f.write(colab_notebook) | ||||
|                     else: | ||||
|                         # something went wrong we will try again later | ||||
|                         continue | ||||
|                 else: | ||||
|                     # serialize notebook to a temp file | ||||
|                     if cls._jupyter_history_logger: | ||||
| @ -449,13 +479,6 @@ class _JupyterObserver(object): | ||||
|                     if prev_script_hash and prev_script_hash == current_script_hash: | ||||
|                         continue | ||||
| 
 | ||||
|                     # remove ipython direct access from the script code | ||||
|                     # we will not be able to run them anyhow | ||||
|                     if replace_ipython_pattern: | ||||
|                         script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code) | ||||
|                     if replace_ipython_display_pattern: | ||||
|                         script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code) | ||||
| 
 | ||||
|                     requirements_txt = '' | ||||
|                     conda_requirements = '' | ||||
|                     # parse jupyter python script and prepare pip requirements (pigar) | ||||
| @ -490,6 +513,14 @@ class _JupyterObserver(object): | ||||
|                                 reqs.add(pkg_name, version, fmodules[name]) | ||||
|                         requirements_txt, conda_requirements = ScriptRequirements.create_requirements_txt(reqs) | ||||
| 
 | ||||
|                     # remove ipython direct access from the script code | ||||
|                     # we will not be able to run them anyhow | ||||
|                     # probably should be better dealt with, because multi line will break it | ||||
|                     if replace_ipython_pattern: | ||||
|                         script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code) | ||||
|                     if replace_ipython_display_pattern: | ||||
|                         script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code) | ||||
| 
 | ||||
|                 # update script | ||||
|                 prev_script_hash = current_script_hash | ||||
|                 data_script = task.data.script | ||||
| @ -515,14 +546,15 @@ class ScriptInfo(object): | ||||
|         return get_logger("Repository Detection") | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False): | ||||
|     def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False): | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             if 'IPython' in sys.modules: | ||||
|                 # noinspection PyPackageRequirements | ||||
|                 from IPython import get_ipython | ||||
|                 if get_ipython(): | ||||
|                     _JupyterObserver.observer(jupyter_notebook_filename, log_history) | ||||
|                     _JupyterObserver.observer( | ||||
|                         jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history) | ||||
|                     get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) | ||||
|                     if log_history: | ||||
|                         get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync) | ||||
| @ -662,6 +694,8 @@ class ScriptInfo(object): | ||||
|                     break | ||||
| 
 | ||||
|             is_google_colab = False | ||||
|             log_history = False | ||||
|             colab_name = None | ||||
|             # check if this is google.colab, then there is no local file | ||||
|             # noinspection PyBroadException | ||||
|             try: | ||||
| @ -673,8 +707,15 @@ class ScriptInfo(object): | ||||
|                 pass | ||||
| 
 | ||||
|             if is_google_colab: | ||||
|                 # check if we can get the notebook | ||||
|                 colab_name, colab_notebook = cls._get_colab_notebook() | ||||
|                 if colab_name is not None: | ||||
|                     notebook_name = colab_name | ||||
|                     log_history = False | ||||
| 
 | ||||
|                 script_entry_point = str(notebook_name or 'notebook').replace( | ||||
|                     '>', '_').replace('<', '_').replace('.ipynb', '.py') | ||||
| 
 | ||||
|                 if not script_entry_point.lower().endswith('.py'): | ||||
|                     script_entry_point += '.py' | ||||
|                 local_ipynb_file = None | ||||
| @ -724,12 +765,29 @@ class ScriptInfo(object): | ||||
| 
 | ||||
|             # install the post store hook, | ||||
|             # notice that if we do not have a local file we serialize/write every time the entire notebook | ||||
|             cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab) | ||||
|             cls._jupyter_install_post_store_hook(local_ipynb_file, notebook_name=colab_name, log_history=log_history) | ||||
| 
 | ||||
|             return script_entry_point | ||||
|         except Exception: | ||||
|             return None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_colab_notebook(cls, timeout=30): | ||||
|         # returns tuple (notebook name, raw string notebook) | ||||
|         # None, None if fails | ||||
|         try: | ||||
|             from google.colab import _message | ||||
| 
 | ||||
|             notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb'] | ||||
|             notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb") | ||||
|             if not notebook_name.endswith(".ipynb"): | ||||
|                 notebook_name += ".ipynb" | ||||
| 
 | ||||
|             # encoding to json | ||||
|             return notebook_name, json.dumps(notebook) | ||||
|         except:  # noqa | ||||
|             return None, None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_entry_point(cls, repo_root, script_path): | ||||
|         repo_root = Path(repo_root).absolute() | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai