mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix Google CoLab code/package detection
This commit is contained in:
parent
aa61fa3f06
commit
8a5f6b7d02
@ -1,10 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
from functools import partial
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import collections
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from furl import furl
|
from furl import furl
|
||||||
@ -30,6 +30,7 @@ class ScriptRequirements(object):
|
|||||||
self._root_folder = root_folder
|
self._root_folder = root_folder
|
||||||
|
|
||||||
def get_requirements(self, entry_point_filename=None):
|
def get_requirements(self, entry_point_filename=None):
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
from ....utilities.pigar.reqs import get_installed_pkgs_detail
|
from ....utilities.pigar.reqs import get_installed_pkgs_detail
|
||||||
from ....utilities.pigar.__main__ import GenerateReqs
|
from ....utilities.pigar.__main__ import GenerateReqs
|
||||||
@ -48,18 +49,21 @@ class ScriptRequirements(object):
|
|||||||
# hack: forcefully insert storage modules if we have them
|
# hack: forcefully insert storage modules if we have them
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements,PyUnresolvedReferences
|
||||||
import boto3
|
import boto3
|
||||||
modules.add('boto3', 'trains.storage', 0)
|
modules.add('boto3', 'trains.storage', 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements,PyUnresolvedReferences
|
||||||
from google.cloud import storage
|
from google.cloud import storage
|
||||||
modules.add('google_cloud_storage', 'trains.storage', 0)
|
modules.add('google_cloud_storage', 'trains.storage', 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements,PyUnresolvedReferences
|
||||||
from azure.storage.blob import ContentSettings
|
from azure.storage.blob import ContentSettings
|
||||||
modules.add('azure_storage_blob', 'trains.storage', 0)
|
modules.add('azure_storage_blob', 'trains.storage', 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -77,7 +81,9 @@ class ScriptRequirements(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# see if this version of torch support tensorboard
|
# see if this version of torch support tensorboard
|
||||||
|
# noinspection PyPackageRequirements,PyUnresolvedReferences
|
||||||
import torch.utils.tensorboard
|
import torch.utils.tensorboard
|
||||||
|
# noinspection PyPackageRequirements,PyUnresolvedReferences
|
||||||
import tensorboard
|
import tensorboard
|
||||||
modules.add('tensorboard', 'torch', 0)
|
modules.add('tensorboard', 'torch', 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -91,6 +97,7 @@ class ScriptRequirements(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
from ..task import Task
|
from ..task import Task
|
||||||
|
# noinspection PyProtectedMember
|
||||||
for package, version in Task._force_requirements.items():
|
for package, version in Task._force_requirements.items():
|
||||||
modules.add(package, 'trains', 0)
|
modules.add(package, 'trains', 0)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -101,6 +108,7 @@ class ScriptRequirements(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create_requirements_txt(reqs, local_pks=None):
|
def create_requirements_txt(reqs, local_pks=None):
|
||||||
# write requirements.txt
|
# write requirements.txt
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
conda_requirements = ''
|
conda_requirements = ''
|
||||||
conda_prefix = os.environ.get('CONDA_PREFIX')
|
conda_prefix = os.environ.get('CONDA_PREFIX')
|
||||||
@ -120,15 +128,16 @@ class ScriptRequirements(object):
|
|||||||
if name == 'pytorch':
|
if name == 'pytorch':
|
||||||
name = 'torch'
|
name = 'torch'
|
||||||
k, v = reqs_lower.get(name, (None, None))
|
k, v = reqs_lower.get(name, (None, None))
|
||||||
if k:
|
if k and v is not None:
|
||||||
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
|
conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
|
||||||
except:
|
except Exception:
|
||||||
conda_requirements = ''
|
conda_requirements = ''
|
||||||
|
|
||||||
# add forced requirements:
|
# add forced requirements:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
from ..task import Task
|
from ..task import Task
|
||||||
|
# noinspection PyProtectedMember
|
||||||
forced_packages = copy(Task._force_requirements)
|
forced_packages = copy(Task._force_requirements)
|
||||||
except Exception:
|
except Exception:
|
||||||
forced_packages = {}
|
forced_packages = {}
|
||||||
@ -198,15 +207,20 @@ class _JupyterObserver(object):
|
|||||||
_sync_event = Event()
|
_sync_event = Event()
|
||||||
_sample_frequency = 30.
|
_sample_frequency = 30.
|
||||||
_first_sample_frequency = 3.
|
_first_sample_frequency = 3.
|
||||||
|
_jupyter_history_logger = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def observer(cls, jupyter_notebook_filename):
|
def observer(cls, jupyter_notebook_filename, log_history):
|
||||||
if cls._thread is not None:
|
if cls._thread is not None:
|
||||||
# order of signaling is important!
|
# order of signaling is important!
|
||||||
cls._exit_event.set()
|
cls._exit_event.set()
|
||||||
cls._sync_event.set()
|
cls._sync_event.set()
|
||||||
cls._thread.join()
|
cls._thread.join()
|
||||||
|
|
||||||
|
if log_history and cls._jupyter_history_logger is None:
|
||||||
|
cls._jupyter_history_logger = _JupyterHistoryLogger()
|
||||||
|
cls._jupyter_history_logger.hook()
|
||||||
|
|
||||||
cls._sync_event.clear()
|
cls._sync_event.clear()
|
||||||
cls._exit_event.clear()
|
cls._exit_event.clear()
|
||||||
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
|
cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
|
||||||
@ -214,7 +228,7 @@ class _JupyterObserver(object):
|
|||||||
cls._thread.start()
|
cls._thread.start()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def signal_sync(cls, *_):
|
def signal_sync(cls, *_, **__):
|
||||||
cls._sync_event.set()
|
cls._sync_event.set()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -233,6 +247,7 @@ class _JupyterObserver(object):
|
|||||||
# load jupyter notebook package
|
# load jupyter notebook package
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
from nbconvert.exporters.script import ScriptExporter
|
from nbconvert.exporters.script import ScriptExporter
|
||||||
_script_exporter = ScriptExporter()
|
_script_exporter = ScriptExporter()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -249,6 +264,7 @@ class _JupyterObserver(object):
|
|||||||
# load IPython
|
# load IPython
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
from IPython import get_ipython
|
from IPython import get_ipython
|
||||||
except Exception:
|
except Exception:
|
||||||
# should not happen
|
# should not happen
|
||||||
@ -266,16 +282,18 @@ class _JupyterObserver(object):
|
|||||||
counter = 0
|
counter = 0
|
||||||
prev_script_hash = None
|
prev_script_hash = None
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
from ....version import __version__
|
from ....version import __version__
|
||||||
our_module = cls.__module__.split('.')[0], __version__
|
our_module = cls.__module__.split('.')[0], __version__
|
||||||
except:
|
except Exception:
|
||||||
our_module = None
|
our_module = None
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
import re
|
import re
|
||||||
replace_ipython_pattern = re.compile('\\n([ \\t]*)get_ipython\(\)')
|
replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\(\)')
|
||||||
except:
|
except Exception:
|
||||||
replace_ipython_pattern = None
|
replace_ipython_pattern = None
|
||||||
|
|
||||||
# main observer loop, check if we need to exit
|
# main observer loop, check if we need to exit
|
||||||
@ -292,6 +310,9 @@ class _JupyterObserver(object):
|
|||||||
if not task:
|
if not task:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
script_code = None
|
||||||
|
fmodules = None
|
||||||
|
current_cell = None
|
||||||
# if we have a local file:
|
# if we have a local file:
|
||||||
if notebook:
|
if notebook:
|
||||||
if not notebook.exists():
|
if not notebook.exists():
|
||||||
@ -302,35 +323,67 @@ class _JupyterObserver(object):
|
|||||||
last_update_ts = notebook.stat().st_mtime
|
last_update_ts = notebook.stat().st_mtime
|
||||||
else:
|
else:
|
||||||
# serialize notebook to a temp file
|
# serialize notebook to a temp file
|
||||||
|
if cls._jupyter_history_logger:
|
||||||
|
script_code, current_cell = cls._jupyter_history_logger.history_to_str()
|
||||||
|
else:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
get_ipython().run_line_magic('notebook', local_jupyter_filename)
|
# noinspection PyBroadException
|
||||||
except Exception as ex:
|
try:
|
||||||
|
os.unlink(local_jupyter_filename)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
get_ipython().run_line_magic('history', '-t -f {}'.format(local_jupyter_filename))
|
||||||
|
with open(local_jupyter_filename, 'r') as f:
|
||||||
|
script_code = f.read()
|
||||||
|
# load the modules
|
||||||
|
from ....utilities.pigar.modules import ImportedModules
|
||||||
|
fmodules = ImportedModules()
|
||||||
|
for nm in set([str(m).split('.')[0] for m in sys.modules]):
|
||||||
|
fmodules.add(nm, 'notebook', 0)
|
||||||
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# get notebook python script
|
# get notebook python script
|
||||||
script_code, resources = _script_exporter.from_filename(local_jupyter_filename)
|
if script_code is None:
|
||||||
current_script_hash = hash(script_code)
|
script_code, _ = _script_exporter.from_filename(local_jupyter_filename)
|
||||||
|
|
||||||
|
current_script_hash = hash(script_code + (current_cell or ''))
|
||||||
if prev_script_hash and prev_script_hash == current_script_hash:
|
if prev_script_hash and prev_script_hash == current_script_hash:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# remove ipython direct access from the script code
|
# remove ipython direct access from the script code
|
||||||
# we will not be able to run them anyhow
|
# we will not be able to run them anyhow
|
||||||
if replace_ipython_pattern:
|
if replace_ipython_pattern:
|
||||||
script_code = replace_ipython_pattern.sub('\n# \g<1>get_ipython()', script_code)
|
script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
|
||||||
|
|
||||||
requirements_txt = ''
|
requirements_txt = ''
|
||||||
conda_requirements = ''
|
conda_requirements = ''
|
||||||
# parse jupyter python script and prepare pip requirements (pigar)
|
# parse jupyter python script and prepare pip requirements (pigar)
|
||||||
# if backend supports requirements
|
# if backend supports requirements
|
||||||
if file_import_modules and Session.check_min_api_version('2.2'):
|
if file_import_modules and Session.check_min_api_version('2.2'):
|
||||||
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
|
if fmodules is None:
|
||||||
|
fmodules, _ = file_import_modules(
|
||||||
|
notebook.parts[-1] if notebook else 'notebook', script_code)
|
||||||
|
if current_cell:
|
||||||
|
cell_fmodules, _ = file_import_modules(
|
||||||
|
notebook.parts[-1] if notebook else 'notebook', current_cell)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
fmodules |= cell_fmodules
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# add current cell to the script
|
||||||
|
if current_cell:
|
||||||
|
script_code += '\n' + current_cell
|
||||||
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
|
fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
installed_pkgs = get_installed_pkgs_detail()
|
installed_pkgs = get_installed_pkgs_detail()
|
||||||
# make sure we are in installed packages
|
# make sure we are in installed packages
|
||||||
if our_module and (our_module[0] not in installed_pkgs):
|
if our_module and (our_module[0] not in installed_pkgs):
|
||||||
installed_pkgs[our_module[0]] = our_module
|
installed_pkgs[our_module[0]] = our_module
|
||||||
|
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
reqs = ReqsModules()
|
reqs = ReqsModules()
|
||||||
for name in fmodules:
|
for name in fmodules:
|
||||||
if name in installed_pkgs:
|
if name in installed_pkgs:
|
||||||
@ -343,8 +396,10 @@ class _JupyterObserver(object):
|
|||||||
data_script = task.data.script
|
data_script = task.data.script
|
||||||
data_script.diff = script_code
|
data_script.diff = script_code
|
||||||
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
|
data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
|
||||||
|
# noinspection PyProtectedMember
|
||||||
task._update_script(script=data_script)
|
task._update_script(script=data_script)
|
||||||
# update requirements
|
# update requirements
|
||||||
|
# noinspection PyProtectedMember
|
||||||
task._update_requirements(requirements=requirements_txt)
|
task._update_requirements(requirements=requirements_txt)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@ -356,14 +411,17 @@ class ScriptInfo(object):
|
|||||||
""" Script info detection plugins, in order of priority """
|
""" Script info detection plugins, in order of priority """
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename):
|
def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, log_history=False):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if 'IPython' in sys.modules:
|
if 'IPython' in sys.modules:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
from IPython import get_ipython
|
from IPython import get_ipython
|
||||||
if get_ipython():
|
if get_ipython():
|
||||||
_JupyterObserver.observer(jupyter_notebook_filename)
|
_JupyterObserver.observer(jupyter_notebook_filename, log_history)
|
||||||
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
|
get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
|
||||||
|
if log_history:
|
||||||
|
get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -377,22 +435,26 @@ class ScriptInfo(object):
|
|||||||
# we can safely assume that we can import the notebook package here
|
# we can safely assume that we can import the notebook package here
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
from notebook.notebookapp import list_running_servers
|
from notebook.notebookapp import list_running_servers
|
||||||
import requests
|
import requests
|
||||||
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
|
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
server_info = next(list_running_servers())
|
server_info = next(list_running_servers())
|
||||||
except Exception:
|
except Exception:
|
||||||
# on some jupyter notebook versions this function can crash on parsing the json file,
|
# on some jupyter notebook versions this function can crash on parsing the json file,
|
||||||
# we will parse it manually here
|
# we will parse it manually here
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
import ipykernel
|
import ipykernel
|
||||||
from glob import glob
|
from glob import glob
|
||||||
import json
|
import json
|
||||||
for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')):
|
for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), 'nbserver-*.json')):
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
with open(f, 'r') as json_data:
|
with open(f, 'r') as json_data:
|
||||||
server_info = json.load(json_data)
|
server_info = json.load(json_data)
|
||||||
except:
|
except Exception:
|
||||||
server_info = None
|
server_info = None
|
||||||
if server_info:
|
if server_info:
|
||||||
break
|
break
|
||||||
@ -403,6 +465,7 @@ class ScriptInfo(object):
|
|||||||
except requests.exceptions.SSLError:
|
except requests.exceptions.SSLError:
|
||||||
# disable SSL check warning
|
# disable SSL check warning
|
||||||
from urllib3.exceptions import InsecureRequestWarning
|
from urllib3.exceptions import InsecureRequestWarning
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
|
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
|
||||||
# fire request
|
# fire request
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
@ -428,6 +491,7 @@ class ScriptInfo(object):
|
|||||||
# check if this is google.colab, then there is no local file
|
# check if this is google.colab, then there is no local file
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
from IPython import get_ipython
|
from IPython import get_ipython
|
||||||
if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
|
if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
|
||||||
is_google_colab = True
|
is_google_colab = True
|
||||||
@ -435,7 +499,10 @@ class ScriptInfo(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if is_google_colab:
|
if is_google_colab:
|
||||||
script_entry_point = notebook_name
|
script_entry_point = str(notebook_name or 'notebook').replace(
|
||||||
|
'>', '_').replace('<', '_').replace('.ipynb', '.py')
|
||||||
|
if not script_entry_point.lower().endswith('.py'):
|
||||||
|
script_entry_point += '.py'
|
||||||
local_ipynb_file = None
|
local_ipynb_file = None
|
||||||
else:
|
else:
|
||||||
# always slash, because this is from uri (so never backslash not even oon windows)
|
# always slash, because this is from uri (so never backslash not even oon windows)
|
||||||
@ -457,7 +524,7 @@ class ScriptInfo(object):
|
|||||||
|
|
||||||
# install the post store hook,
|
# install the post store hook,
|
||||||
# notice that if we do not have a local file we serialize/write every time the entire notebook
|
# notice that if we do not have a local file we serialize/write every time the entire notebook
|
||||||
cls._jupyter_install_post_store_hook(local_ipynb_file)
|
cls._jupyter_install_post_store_hook(local_ipynb_file, is_google_colab)
|
||||||
|
|
||||||
return script_entry_point
|
return script_entry_point
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -641,3 +708,135 @@ class ScriptInfo(object):
|
|||||||
class ScriptInfoResult(object):
|
class ScriptInfoResult(object):
|
||||||
script = attr.ib(default=None)
|
script = attr.ib(default=None)
|
||||||
warning_messages = attr.ib(factory=list)
|
warning_messages = attr.ib(factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class _JupyterHistoryLogger(object):
|
||||||
|
_reg_replace_ipython = r'\n([ \t]*)get_ipython\(\)'
|
||||||
|
_reg_replace_magic = r'\n([ \t]*)%'
|
||||||
|
_reg_replace_bang = r'\n([ \t]*)!'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._exception_raised = False
|
||||||
|
self._cells_code = {}
|
||||||
|
self._counter = 0
|
||||||
|
self._ip = None
|
||||||
|
self._current_cell = None
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
self._replace_ipython_pattern = re.compile(self._reg_replace_ipython)
|
||||||
|
self._replace_magic_pattern = re.compile(self._reg_replace_magic)
|
||||||
|
self._replace_bang_pattern = re.compile(self._reg_replace_bang)
|
||||||
|
except Exception:
|
||||||
|
self._replace_ipython_pattern = None
|
||||||
|
self._replace_magic_pattern = None
|
||||||
|
self._replace_bang_pattern = None
|
||||||
|
|
||||||
|
def hook(self, ip=None):
|
||||||
|
if not ip:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# noinspection PyPackageRequirements
|
||||||
|
from IPython import get_ipython
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
self._ip = get_ipython()
|
||||||
|
else:
|
||||||
|
self._ip = ip
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# if this is colab, the callbacks do not contain the raw_cell content, so we have to patch it
|
||||||
|
if 'google.colab' in self._ip.extension_manager.loaded:
|
||||||
|
self._ip._org_run_cell = self._ip.run_cell
|
||||||
|
self._ip.run_cell = partial(self._patched_run_cell, self._ip)
|
||||||
|
except Exception as ex:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# start with the current history
|
||||||
|
self._initialize_history()
|
||||||
|
self._ip.events.register('post_run_cell', self._post_cell_callback)
|
||||||
|
self._ip.events.register('pre_run_cell', self._pre_cell_callback)
|
||||||
|
self._ip.set_custom_exc((Exception,), self._exception_callback)
|
||||||
|
|
||||||
|
def _patched_run_cell(self, shell, *args, **kwargs):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
raw_cell = kwargs.get('raw_cell') or args[0]
|
||||||
|
self._current_cell = raw_cell
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
return shell._org_run_cell(*args, **kwargs)
|
||||||
|
|
||||||
|
def history(self, filename):
|
||||||
|
with open(filename, 'wt') as f:
|
||||||
|
for k, v in sorted(self._cells_code.items(), key=lambda p: p[0]):
|
||||||
|
f.write(v)
|
||||||
|
|
||||||
|
def history_to_str(self):
|
||||||
|
# return a pair: (history as str, current cell if we are in still in cell execution otherwise None)
|
||||||
|
return '\n'.join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell
|
||||||
|
|
||||||
|
# noinspection PyUnusedLocal
|
||||||
|
def _exception_callback(self, shell, etype, value, tb, tb_offset=None):
|
||||||
|
self._exception_raised = True
|
||||||
|
return shell.showtraceback()
|
||||||
|
|
||||||
|
def _pre_cell_callback(self, *args, **_):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
if args:
|
||||||
|
self._current_cell = args[0].raw_cell
|
||||||
|
# we might have this value from somewhere else
|
||||||
|
if self._current_cell:
|
||||||
|
self._current_cell = self._conform_code(self._current_cell, replace_magic_bang=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _post_cell_callback(self, *_, **__):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
self._current_cell = None
|
||||||
|
if self._exception_raised:
|
||||||
|
# do nothing
|
||||||
|
self._exception_raised = False
|
||||||
|
return
|
||||||
|
|
||||||
|
self._exception_raised = False
|
||||||
|
# add the cell history
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
cell_code = '\n' + self._ip.history_manager.input_hist_parsed[-1]
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
# fix magic / bang in code
|
||||||
|
cell_code = self._conform_code(cell_code)
|
||||||
|
|
||||||
|
self._cells_code[self._counter] = cell_code
|
||||||
|
self._counter += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _initialize_history(self):
|
||||||
|
# only once
|
||||||
|
if -1 in self._cells_code:
|
||||||
|
return
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
cell_code = '\n' + '\n'.join(self._ip.history_manager.input_hist_parsed[:-1])
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
cell_code = self._conform_code(cell_code)
|
||||||
|
self._cells_code[-1] = cell_code
|
||||||
|
|
||||||
|
def _conform_code(self, cell_code, replace_magic_bang=False):
|
||||||
|
# fix magic / bang in code
|
||||||
|
if self._replace_ipython_pattern:
|
||||||
|
cell_code = self._replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', cell_code)
|
||||||
|
if replace_magic_bang and self._replace_magic_pattern and self._replace_bang_pattern:
|
||||||
|
cell_code = self._replace_magic_pattern.sub(r'\n# \g<1>%', cell_code)
|
||||||
|
cell_code = self._replace_bang_pattern.sub(r'\n# \g<1>!', cell_code)
|
||||||
|
return cell_code
|
||||||
|
Loading…
Reference in New Issue
Block a user