diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index c1632eac..65684c53 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -1,5 +1,6 @@ import os import sys +from tempfile import mkstemp import attr import collections @@ -119,20 +120,28 @@ class ScriptRequirements(object): class _JupyterObserver(object): _thread = None _exit_event = Event() + _sync_event = Event() _sample_frequency = 30. _first_sample_frequency = 3. @classmethod def observer(cls, jupyter_notebook_filename): if cls._thread is not None: + # order of signaling is important! cls._exit_event.set() + cls._sync_event.set() cls._thread.join() + cls._sync_event.clear() cls._exit_event.clear() cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) cls._thread.daemon = True cls._thread.start() + @classmethod + def signal_sync(cls, *_): + cls._sync_event.set() + @classmethod def _daemon(cls, jupyter_notebook_filename): from trains import Task @@ -153,28 +162,59 @@ class _JupyterObserver(object): logger.setLevel(logging.WARNING) except Exception: file_import_modules = None - # main observer loop - notebook = Path(jupyter_notebook_filename) + # load IPython + # noinspection PyBroadException + try: + from IPython import get_ipython + except Exception: + # should not happen + get_ipython = None + + # setup local notebook files + if jupyter_notebook_filename: + notebook = Path(jupyter_notebook_filename) + local_jupyter_filename = jupyter_notebook_filename + else: + notebook = None + fd, local_jupyter_filename = mkstemp(suffix='.ipynb') + os.close(fd) last_update_ts = None counter = 0 prev_script_hash = None + # main observer loop while True: - if cls._exit_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency): + # wait for timeout or sync event + cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency) + # check if we need to exit + if cls._exit_event.wait(timeout=0.): return + cls._sync_event.clear() counter += 1 # noinspection PyBroadException try: - if not notebook.exists(): - continue - # check if notebook changed - if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: - continue - last_update_ts = notebook.stat().st_mtime + # if there is no task connected, do nothing task = Task.current_task() if not task: continue + + # if we have a local file: + if notebook: + if not notebook.exists(): + continue + # check if notebook changed + if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0: + continue + last_update_ts = notebook.stat().st_mtime + else: + # serialize notebook to a temp file + # noinspection PyBroadException + try: + get_ipython().run_line_magic('notebook', local_jupyter_filename) + except Exception as ex: + continue + # get notebook python script - script_code, resources = _script_exporter.from_filename(jupyter_notebook_filename) + script_code, resources = _script_exporter.from_filename(local_jupyter_filename) current_script_hash = hash(script_code) if prev_script_hash and prev_script_hash == current_script_hash: continue @@ -198,8 +238,7 @@ class _JupyterObserver(object): data_script.requirements = {'pip': requirements_txt} task._update_script(script=data_script) # update requirements - if requirements_txt: - task._update_requirements(requirements=requirements_txt) + task._update_requirements(requirements=requirements_txt) except Exception: pass @@ -217,12 +256,15 @@ class ScriptInfo(object): from IPython import get_ipython if get_ipython(): _JupyterObserver.observer(jupyter_notebook_filename) + get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) except Exception: pass @classmethod def _get_jupyter_notebook_filename(cls): - if not sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'): + if not (sys.argv[0].endswith(os.path.sep+'ipykernel_launcher.py') or + sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \ + or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'): return None # we can safely assume that we can import the notebook package here @@ -244,23 +286,45 @@ class ScriptInfo(object): cur_notebook = n break - notebook_path = cur_notebook['notebook']['path'] - # always slash, because this is from uri (so never backslash not even oon windows) - entry_point_filename = notebook_path.split('/')[-1] + notebook_path = cur_notebook['notebook'].get('path', '') + notebook_name = cur_notebook['notebook'].get('name', '') - # now we should try to find the actual file - entry_point = (Path.cwd() / entry_point_filename).absolute() - if not entry_point.is_file(): - entry_point = (Path.cwd() / notebook_path).absolute() + is_google_colab = False + # check if this is google.colab, then there is no local file + # noinspection PyBroadException + try: + from IPython import get_ipython + if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded: + is_google_colab = True + except Exception: + pass - # install the post store hook, so always have a synced file in the system - cls._jupyter_install_post_store_hook(entry_point.as_posix()) + if is_google_colab: + script_entry_point = notebook_name + local_ipynb_file = None + else: + # always slash, because this is from uri (so never backslash not even oon windows) + entry_point_filename = notebook_path.split('/')[-1] - # now replace the .ipynb with .py - # we assume we will have that file available with the Jupyter notebook plugin - entry_point = entry_point.with_suffix('.py') + # now we should try to find the actual file + entry_point = (Path.cwd() / entry_point_filename).absolute() + if not entry_point.is_file(): + entry_point = (Path.cwd() / notebook_path).absolute() - return entry_point.as_posix() + # get local ipynb for observer + local_ipynb_file = entry_point.as_posix() + + # now replace the .ipynb with .py + # we assume we will have that file available with the Jupyter notebook plugin + entry_point = entry_point.with_suffix('.py') + + script_entry_point = entry_point.as_posix() + + # install the post store hook, + # notice that if we do not have a local file we serialize/write every time the entire notebook + cls._jupyter_install_post_store_hook(local_ipynb_file) + + return script_entry_point except Exception: return None