From 8a5f6b7d029779293eaa5899d76736e61fb0c3a6 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 13 Jun 2020 22:12:28 +0300 Subject: [PATCH] Fix Google CoLab code/package detection --- .../backend_interface/task/repo/scriptinfo.py | 243 ++++++++++++++++-- 1 file changed, 221 insertions(+), 22 deletions(-) diff --git a/trains/backend_interface/task/repo/scriptinfo.py b/trains/backend_interface/task/repo/scriptinfo.py index 4b6d3a2b..0b2ab265 100644 --- a/trains/backend_interface/task/repo/scriptinfo.py +++ b/trains/backend_interface/task/repo/scriptinfo.py @@ -1,10 +1,10 @@ import os import sys from copy import copy +from functools import partial from tempfile import mkstemp import attr -import collections import logging import json from furl import furl @@ -30,6 +30,7 @@ class ScriptRequirements(object): self._root_folder = root_folder def get_requirements(self, entry_point_filename=None): + # noinspection PyBroadException try: from ....utilities.pigar.reqs import get_installed_pkgs_detail from ....utilities.pigar.__main__ import GenerateReqs @@ -48,18 +49,21 @@ class ScriptRequirements(object): # hack: forcefully insert storage modules if we have them # noinspection PyBroadException try: + # noinspection PyPackageRequirements,PyUnresolvedReferences import boto3 modules.add('boto3', 'trains.storage', 0) except Exception: pass # noinspection PyBroadException try: + # noinspection PyPackageRequirements,PyUnresolvedReferences from google.cloud import storage modules.add('google_cloud_storage', 'trains.storage', 0) except Exception: pass # noinspection PyBroadException try: + # noinspection PyPackageRequirements,PyUnresolvedReferences from azure.storage.blob import ContentSettings modules.add('azure_storage_blob', 'trains.storage', 0) except Exception: @@ -77,7 +81,9 @@ class ScriptRequirements(object): # noinspection PyBroadException try: # see if this version of torch support tensorboard + # noinspection PyPackageRequirements,PyUnresolvedReferences import torch.utils.tensorboard + # noinspection PyPackageRequirements,PyUnresolvedReferences import tensorboard modules.add('tensorboard', 'torch', 0) except Exception: @@ -91,6 +97,7 @@ class ScriptRequirements(object): # noinspection PyBroadException try: from ..task import Task + # noinspection PyProtectedMember for package, version in Task._force_requirements.items(): modules.add(package, 'trains', 0) except Exception: @@ -101,6 +108,7 @@ class ScriptRequirements(object): @staticmethod def create_requirements_txt(reqs, local_pks=None): # write requirements.txt + # noinspection PyBroadException try: conda_requirements = '' conda_prefix = os.environ.get('CONDA_PREFIX') @@ -120,15 +128,16 @@ class ScriptRequirements(object): if name == 'pytorch': name = 'torch' k, v = reqs_lower.get(name, (None, None)) - if k: + if k and v is not None: conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version) - except: + except Exception: conda_requirements = '' # add forced requirements: # noinspection PyBroadException try: from ..task import Task + # noinspection PyProtectedMember forced_packages = copy(Task._force_requirements) except Exception: forced_packages = {} @@ -198,15 +207,20 @@ class _JupyterObserver(object): _sync_event = Event() _sample_frequency = 30. _first_sample_frequency = 3. + _jupyter_history_logger = None @classmethod - def observer(cls, jupyter_notebook_filename): + def observer(cls, jupyter_notebook_filename, log_history): if cls._thread is not None: # order of signaling is important! cls._exit_event.set() cls._sync_event.set() cls._thread.join() + if log_history and cls._jupyter_history_logger is None: + cls._jupyter_history_logger = _JupyterHistoryLogger() + cls._jupyter_history_logger.hook() + cls._sync_event.clear() cls._exit_event.clear() cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, )) @@ -214,7 +228,7 @@ class _JupyterObserver(object): cls._thread.start() @classmethod - def signal_sync(cls, *_): + def signal_sync(cls, *_, **__): cls._sync_event.set() @classmethod @@ -233,6 +247,7 @@ class _JupyterObserver(object): # load jupyter notebook package # noinspection PyBroadException try: + # noinspection PyPackageRequirements from nbconvert.exporters.script import ScriptExporter _script_exporter = ScriptExporter() except Exception: @@ -249,6 +264,7 @@ class _JupyterObserver(object): # load IPython # noinspection PyBroadException try: + # noinspection PyPackageRequirements from IPython import get_ipython except Exception: # should not happen @@ -266,16 +282,18 @@ class _JupyterObserver(object): counter = 0 prev_script_hash = None + # noinspection PyBroadException try: from ....version import __version__ our_module = cls.__module__.split('.')[0], __version__ - except: + except Exception: our_module = None + # noinspection PyBroadException try: import re - replace_ipython_pattern = re.compile('\\n([ \\t]*)get_ipython\(\)') - except: + replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)') + except Exception: replace_ipython_pattern = None # main observer loop, check if we need to exit @@ -292,6 +310,9 @@ class _JupyterObserver(object): if not task: continue + script_code = None + fmodules = None + current_cell = None # if we have a local file: if notebook: if not notebook.exists(): @@ -302,35 +323,67 @@ class _JupyterObserver(object): last_update_ts = notebook.stat().st_mtime else: # serialize notebook to a temp file - # noinspection PyBroadException - try: - get_ipython().run_line_magic('notebook', local_jupyter_filename) - except Exception as ex: - continue + if cls._jupyter_history_logger: + script_code, current_cell = cls._jupyter_history_logger.history_to_str() + else: + # noinspection PyBroadException + try: + # noinspection PyBroadException + try: + os.unlink(local_jupyter_filename) + except Exception: + pass + get_ipython().run_line_magic('history', '-t -f {}'.format(local_jupyter_filename)) + with open(local_jupyter_filename, 'r') as f: + script_code = f.read() + # load the modules + from ....utilities.pigar.modules import ImportedModules + fmodules = ImportedModules() + for nm in set([str(m).split('.')[0] for m in sys.modules]): + fmodules.add(nm, 'notebook', 0) + except Exception: + continue # get notebook python script - script_code, resources = _script_exporter.from_filename(local_jupyter_filename) - current_script_hash = hash(script_code) + if script_code is None: + script_code, _ = _script_exporter.from_filename(local_jupyter_filename) + + current_script_hash = hash(script_code + (current_cell or '')) if prev_script_hash and prev_script_hash == current_script_hash: continue # remove ipython direct access from the script code # we will not be able to run them anyhow if replace_ipython_pattern: - script_code = replace_ipython_pattern.sub('\n# \g<1>get_ipython()', script_code) + script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code) requirements_txt = '' conda_requirements = '' # parse jupyter python script and prepare pip requirements (pigar) # if backend supports requirements if file_import_modules and Session.check_min_api_version('2.2'): - fmodules, _ = file_import_modules(notebook.parts[-1], script_code) + if fmodules is None: + fmodules, _ = file_import_modules( + notebook.parts[-1] if notebook else 'notebook', script_code) + if current_cell: + cell_fmodules, _ = file_import_modules( + notebook.parts[-1] if notebook else 'notebook', current_cell) + # noinspection PyBroadException + try: + fmodules |= cell_fmodules + except Exception: + pass + # add current cell to the script + if current_cell: + script_code += '\n' + current_cell fmodules = ScriptRequirements.add_trains_used_packages(fmodules) + # noinspection PyUnboundLocalVariable installed_pkgs = get_installed_pkgs_detail() # make sure we are in installed packages if our_module and (our_module[0] not in installed_pkgs): installed_pkgs[our_module[0]] = our_module + # noinspection PyUnboundLocalVariable reqs = ReqsModules() for name in fmodules: if name in installed_pkgs: @@ -343,8 +396,10 @@ class _JupyterObserver(object): data_script = task.data.script data_script.diff = script_code data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements} + # noinspection PyProtectedMember task._update_script(script=data_script) # update requirements + # noinspection PyProtectedMember task._update_requirements(requirements=requirements_txt) except Exception: pass @@ -356,14 +411,17 @@ class ScriptInfo(object): """ Script info detection plugins, in order of priority """ @classmethod - def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename): + def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False): # noinspection PyBroadException try: if 'IPython' in sys.modules: + # noinspection PyPackageRequirements from IPython import get_ipython if get_ipython(): - _JupyterObserver.observer(jupyter_notebook_filename) + _JupyterObserver.observer(jupyter_notebook_filename, log_history) get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync) + if log_history: + get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync) except Exception: pass @@ -377,22 +435,26 @@ class ScriptInfo(object): # we can safely assume that we can import the notebook package here # noinspection PyBroadException try: + # noinspection PyPackageRequirements from notebook.notebookapp import list_running_servers import requests current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '') + # noinspection PyBroadException try: server_info = next(list_running_servers()) except Exception: # on some jupyter notebook versions this function can crash on parsing the json file, # we will parse it manually here + # noinspection PyPackageRequirements import ipykernel from glob import glob import json for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')): + # noinspection PyBroadException try: with open(f, 'r') as json_data: server_info = json.load(json_data) - except: + except Exception: server_info = None if server_info: break @@ -403,6 +465,7 @@ class ScriptInfo(object): except requests.exceptions.SSLError: # disable SSL check warning from urllib3.exceptions import InsecureRequestWarning + # noinspection PyUnresolvedReferences requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) # fire request r = requests.get( @@ -428,6 +491,7 @@ class ScriptInfo(object): # check if this is google.colab, then there is no local file # noinspection PyBroadException try: + # noinspection PyPackageRequirements from IPython import get_ipython if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded: is_google_colab = True @@ -435,7 +499,10 @@ class ScriptInfo(object): pass if is_google_colab: - script_entry_point = notebook_name + script_entry_point = str(notebook_name or 'notebook').replace( + '>', '_').replace('<', '_').replace('.ipynb', '.py') + if not script_entry_point.lower().endswith('.py'): + script_entry_point += '.py' local_ipynb_file = None else: # always slash, because this is from uri (so never backslash not even oon windows) @@ -457,7 +524,7 @@ class ScriptInfo(object): # install the post store hook, # notice that if we do not have a local file we serialize/write every time the entire notebook - cls._jupyter_install_post_store_hook(local_ipynb_file) + cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab) return script_entry_point except Exception: @@ -641,3 +708,135 @@ class ScriptInfo(object): class ScriptInfoResult(object): script = attr.ib(default=None) warning_messages = attr.ib(factory=list) + + +class _JupyterHistoryLogger(object): + _reg_replace_ipython = r'\n([ \t]*)get_ipython\(\)' + _reg_replace_magic = r'\n([ \t]*)%' + _reg_replace_bang = r'\n([ \t]*)!' + + def __init__(self): + self._exception_raised = False + self._cells_code = {} + self._counter = 0 + self._ip = None + self._current_cell = None + # noinspection PyBroadException + try: + import re + self._replace_ipython_pattern = re.compile(self._reg_replace_ipython) + self._replace_magic_pattern = re.compile(self._reg_replace_magic) + self._replace_bang_pattern = re.compile(self._reg_replace_bang) + except Exception: + self._replace_ipython_pattern = None + self._replace_magic_pattern = None + self._replace_bang_pattern = None + + def hook(self, ip=None): + if not ip: + # noinspection PyBroadException + try: + # noinspection PyPackageRequirements + from IPython import get_ipython + except Exception: + return + self._ip = get_ipython() + else: + self._ip = ip + + # noinspection PyBroadException + try: + # if this is colab, the callbacks do not contain the raw_cell content, so we have to patch it + if 'google.colab' in self._ip.extension_manager.loaded: + self._ip._org_run_cell = self._ip.run_cell + self._ip.run_cell = partial(self._patched_run_cell, self._ip) + except Exception as ex: + pass + + # start with the current history + self._initialize_history() + self._ip.events.register('post_run_cell', self._post_cell_callback) + self._ip.events.register('pre_run_cell', self._pre_cell_callback) + self._ip.set_custom_exc((Exception,), self._exception_callback) + + def _patched_run_cell(self, shell, *args, **kwargs): + # noinspection PyBroadException + try: + raw_cell = kwargs.get('raw_cell') or args[0] + self._current_cell = raw_cell + except Exception: + pass + # noinspection PyProtectedMember + return shell._org_run_cell(*args, **kwargs) + + def history(self, filename): + with open(filename, 'wt') as f: + for k, v in sorted(self._cells_code.items(), key=lambda p: p[0]): + f.write(v) + + def history_to_str(self): + # return a pair: (history as str, current cell if we are in still in cell execution otherwise None) + return '\n'.join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell + + # noinspection PyUnusedLocal + def _exception_callback(self, shell, etype, value, tb, tb_offset=None): + self._exception_raised = True + return shell.showtraceback() + + def _pre_cell_callback(self, *args, **_): + # noinspection PyBroadException + try: + if args: + self._current_cell = args[0].raw_cell + # we might have this value from somewhere else + if self._current_cell: + self._current_cell = self._conform_code(self._current_cell, replace_magic_bang=True) + except Exception: + pass + + def _post_cell_callback(self, *_, **__): + # noinspection PyBroadException + try: + self._current_cell = None + if self._exception_raised: + # do nothing + self._exception_raised = False + return + + self._exception_raised = False + # add the cell history + # noinspection PyBroadException + try: + cell_code = '\n' + self._ip.history_manager.input_hist_parsed[-1] + except Exception: + return + + # fix magic / bang in code + cell_code = self._conform_code(cell_code) + + self._cells_code[self._counter] = cell_code + self._counter += 1 + except Exception: + pass + + def _initialize_history(self): + # only once + if -1 in self._cells_code: + return + # noinspection PyBroadException + try: + cell_code = '\n' + '\n'.join(self._ip.history_manager.input_hist_parsed[:-1]) + except Exception: + return + + cell_code = self._conform_code(cell_code) + self._cells_code[-1] = cell_code + + def _conform_code(self, cell_code, replace_magic_bang=False): + # fix magic / bang in code + if self._replace_ipython_pattern: + cell_code = self._replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', cell_code) + if replace_magic_bang and self._replace_magic_pattern and self._replace_bang_pattern: + cell_code = self._replace_magic_pattern.sub(r'\n# \g<1>%', cell_code) + cell_code = self._replace_bang_pattern.sub(r'\n# \g<1>!', cell_code) + return cell_code