import os
import sys
from tempfile import mkstemp

import attr
import collections
import logging
from furl import furl
from pathlib2 import Path
from threading import Thread, Event

from ....backend_api import Session
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult

_logger = get_logger("Repository Detection")


class ScriptInfoError(Exception):
    pass


class ScriptRequirements(object):
    def __init__(self, root_folder):
        self._root_folder = root_folder

    @staticmethod
    def get_installed_pkgs_detail(reqs):
        """
        HACK: bugfix of the original pigar get_installed_pkgs_detail

        Get mapping for import top level name
        and install package name with version.
        """
        mapping = dict()

        for path in sys.path:
            if os.path.isdir(path) and path.rstrip('/').endswith(
                    ('site-packages', 'dist-packages')):
                new_mapping = reqs._search_path(path)
                # BUGFIX:
                # override with previous, just like python resolves imports, the first match is the one used.
                # unlike the original implementation, where the last one is used.
                new_mapping.update(mapping)
                mapping = new_mapping

        return mapping

    def get_requirements(self):
        try:
            from pigar import reqs
            reqs.project_import_modules = ScriptRequirements._patched_project_import_modules
            from pigar.__main__ import GenerateReqs
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
            try:
                # first try our version, if we fail revert to the internal implementation
                installed_pkgs = self.get_installed_pkgs_detail(reqs)
            except Exception:
                installed_pkgs = reqs.get_installed_pkgs_detail()
            gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
                              ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints'])
            reqs, try_imports, guess = gr.extract_reqs()
            return self.create_requirements_txt(reqs)
        except Exception:
            return ''

    @staticmethod
    def _patched_project_import_modules(project_path, ignores):
        """
        copied from pigar req.project_import_modules
        patching, os.getcwd() is incorrectly used
        """
        from pigar.modules import ImportedModules
        from pigar.reqs import file_import_modules
        modules = ImportedModules()
        try_imports = set()
        local_mods = list()
        cur_dir = project_path  # os.getcwd()
        ignore_paths = collections.defaultdict(set)
        if not ignores:
            ignore_paths[project_path].add('.git')
        else:
            for path in ignores:
                parent_dir = os.path.dirname(path)
                ignore_paths[parent_dir].add(os.path.basename(path))

        for dirpath, dirnames, files in os.walk(project_path, followlinks=True):
            if dirpath in ignore_paths:
                dirnames[:] = [d for d in dirnames
                               if d not in ignore_paths[dirpath]]
            py_files = list()
            for fn in files:
                # C extension.
                if fn.endswith('.so'):
                    local_mods.append(fn[:-3])
                # Normal Python file.
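                # both C extensions and plain .py files are recorded in local_mods,
                # so pigar can tell project-local modules apart from installed packages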
                if fn.endswith('.py'):
                    local_mods.append(fn[:-3])
                    py_files.append(fn)
            if '__init__.py' in files:
                local_mods.append(os.path.basename(dirpath))
            for file in py_files:
                fpath = os.path.join(dirpath, file)
                fake_path = fpath.split(cur_dir)[1][1:]
                with open(fpath, 'rb') as f:
                    fmodules, try_ipts = file_import_modules(fake_path, f.read())
                    modules |= fmodules
                    try_imports |= try_ipts

        # hack: forcefully insert storage modules if we have them
        # noinspection PyBroadException
        try:
            import boto3
            modules.add('boto3', 'trains.storage', 0)
        except Exception:
            pass
        # noinspection PyBroadException
        try:
            from google.cloud import storage
            modules.add('google_cloud_storage', 'trains.storage', 0)
        except Exception:
            pass
        # noinspection PyBroadException
        try:
            from azure.storage.blob import ContentSettings
            modules.add('azure_storage_blob', 'trains.storage', 0)
        except Exception:
            pass

        return modules, try_imports, local_mods

    @staticmethod
    def create_requirements_txt(reqs):
        # write requirements.txt
        # python version header
        requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'

        # requirement summary
        requirements_txt += '\n'
        for k, v in reqs.sorted_items():
            # requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
            if k == '-e':
                requirements_txt += '{0} {1}\n'.format(k, v.version)
            elif v:
                requirements_txt += '{0} {1} {2}\n'.format(k, '==', v.version)
            else:
                requirements_txt += '{0}\n'.format(k)

        # requirements details (in comments)
        requirements_txt += '\n' + \
                            '# Detailed import analysis\n' \
                            '# **************************\n'
        for k, v in reqs.sorted_items():
            requirements_txt += '\n'
            if k == '-e':
                requirements_txt += '# IMPORT PACKAGE {0} {1}\n'.format(k, v.version)
            else:
                requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
            requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])

        return requirements_txt


class _JupyterObserver(object):
    _thread = None
    _exit_event = Event()
    _sync_event = Event()
    _sample_frequency = 30.
    _first_sample_frequency = 3.

    @classmethod
    def observer(cls, jupyter_notebook_filename):
        if cls._thread is not None:
            # order of signaling is important!
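            # raise the exit flag first, then wake the daemon via the sync event,
            # so the loop notices the exit request and returns before we join the thread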
            cls._exit_event.set()
            cls._sync_event.set()
            cls._thread.join()

        cls._sync_event.clear()
        cls._exit_event.clear()
        cls._thread = Thread(target=cls._daemon, args=(jupyter_notebook_filename, ))
        cls._thread.daemon = True
        cls._thread.start()

    @classmethod
    def signal_sync(cls, *_):
        cls._sync_event.set()

    @classmethod
    def _daemon(cls, jupyter_notebook_filename):
        from trains import Task

        # load jupyter notebook package
        # noinspection PyBroadException
        try:
            from nbconvert.exporters.script import ScriptExporter
            _script_exporter = ScriptExporter()
        except Exception:
            return

        # load pigar
        # noinspection PyBroadException
        try:
            from pigar.reqs import get_installed_pkgs_detail, file_import_modules
            from pigar.modules import ReqsModules
            from pigar.log import logger
            logger.setLevel(logging.WARNING)
        except Exception:
            file_import_modules = None

        # load IPython
        # noinspection PyBroadException
        try:
            from IPython import get_ipython
        except Exception:
            # should not happen
            get_ipython = None

        # setup local notebook files
        if jupyter_notebook_filename:
            notebook = Path(jupyter_notebook_filename)
            local_jupyter_filename = jupyter_notebook_filename
        else:
            notebook = None
            fd, local_jupyter_filename = mkstemp(suffix='.ipynb')
            os.close(fd)
        last_update_ts = None
        counter = 0
        prev_script_hash = None

        # main observer loop
        while True:
            # wait for timeout or sync event
            cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
            # check if we need to exit
            if cls._exit_event.wait(timeout=0.):
                return
            cls._sync_event.clear()
            counter += 1
            # noinspection PyBroadException
            try:
                # if there is no task connected, do nothing
                task = Task.current_task()
                if not task:
                    continue
                # if we have a local file:
                if notebook:
                    if not notebook.exists():
                        continue
                    # check if notebook changed
                    if last_update_ts is not None and notebook.stat().st_mtime - last_update_ts <= 0:
                        continue
                    last_update_ts = notebook.stat().st_mtime
                else:
                    # serialize notebook to a temp file
                    # noinspection PyBroadException
                    try:
                        get_ipython().run_line_magic('notebook', local_jupyter_filename)
                    except Exception as ex:
                        continue
                # get notebook python script
                script_code, resources = _script_exporter.from_filename(local_jupyter_filename)
                current_script_hash = hash(script_code)
                if prev_script_hash and prev_script_hash == current_script_hash:
                    continue
                requirements_txt = ''
                # parse jupyter python script and prepare pip requirements (pigar)
                # if backend supports requirements
                if file_import_modules and Session.check_min_api_version('2.2'):
                    # fall back to the temp file name when there is no local notebook file
                    fmodules, _ = file_import_modules(
                        notebook.parts[-1] if notebook else Path(local_jupyter_filename).name, script_code)
                    installed_pkgs = get_installed_pkgs_detail()
                    reqs = ReqsModules()
                    for name in fmodules:
                        if name in installed_pkgs:
                            pkg_name, version = installed_pkgs[name]
                            reqs.add(pkg_name, version, fmodules[name])
                    requirements_txt = ScriptRequirements.create_requirements_txt(reqs)

                # update script
                prev_script_hash = current_script_hash
                data_script = task.data.script
                data_script.diff = script_code
                data_script.requirements = {'pip': requirements_txt}
                task._update_script(script=data_script)
                # update requirements
                task._update_requirements(requirements=requirements_txt)
            except Exception:
                pass


class ScriptInfo(object):

    plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
    """ Script info detection plugins, in order of priority """

    @classmethod
    def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename):
        # noinspection PyBroadException
        try:
            if 'IPython' in sys.modules:
                from IPython import get_ipython
                if get_ipython():
                    _JupyterObserver.observer(jupyter_notebook_filename)
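                    # every executed cell triggers signal_sync, so the observer thread
                    # re-serializes the notebook and refreshes the stored script/requirements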
                    get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
        except Exception:
            pass

    @classmethod
    def _get_jupyter_notebook_filename(cls):
        if not (sys.argv[0].endswith(os.path.sep + 'ipykernel_launcher.py') or
                sys.argv[0].endswith(os.path.join(os.path.sep, 'ipykernel', '__main__.py'))) \
                or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
            return None

        # we can safely assume that we can import the notebook package here
        # noinspection PyBroadException
        try:
            from notebook.notebookapp import list_running_servers
            import requests
            current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
            server_info = next(list_running_servers())
            r = requests.get(
                url=server_info['url'] + 'api/sessions',
                headers={'Authorization': 'token {}'.format(server_info.get('token', '')), })
            r.raise_for_status()
            notebooks = r.json()

            cur_notebook = None
            for n in notebooks:
                if n['kernel']['id'] == current_kernel:
                    cur_notebook = n
                    break

            notebook_path = cur_notebook['notebook'].get('path', '')
            notebook_name = cur_notebook['notebook'].get('name', '')

            is_google_colab = False
            # check if this is google.colab, then there is no local file
            # noinspection PyBroadException
            try:
                from IPython import get_ipython
                if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
                    is_google_colab = True
            except Exception:
                pass

            if is_google_colab:
                script_entry_point = notebook_name
                local_ipynb_file = None
            else:
                # always slash, because this is from a uri (so never backslash, not even on windows)
                entry_point_filename = notebook_path.split('/')[-1]

                # now we should try to find the actual file
                entry_point = (Path.cwd() / entry_point_filename).absolute()
                if not entry_point.is_file():
                    entry_point = (Path.cwd() / notebook_path).absolute()

                # get local ipynb for observer
                local_ipynb_file = entry_point.as_posix()

                # now replace the .ipynb with .py
                # we assume we will have that file available with the Jupyter notebook plugin
                entry_point = entry_point.with_suffix('.py')

                script_entry_point = entry_point.as_posix()

            # install the post store hook,
            # notice that if we do not have a local file we serialize/write every time the entire notebook
            cls._jupyter_install_post_store_hook(local_ipynb_file)

            return script_entry_point
        except Exception:
            return None

    @classmethod
    def _get_entry_point(cls, repo_root, script_path):
        repo_root = Path(repo_root).absolute()

        try:
            # Use os.path.relpath as it calculates up dir movements (../)
            entry_point = os.path.relpath(str(script_path), str(Path.cwd()))
        except ValueError:
            # Working directory not under repository root
            entry_point = script_path.relative_to(repo_root)

        return Path(entry_point).as_posix()

    @classmethod
    def _get_working_dir(cls, repo_root):
        repo_root = Path(repo_root).absolute()

        try:
            return Path.cwd().relative_to(repo_root).as_posix()
        except ValueError:
            # Working directory not under repository root
            return os.path.curdir

    @classmethod
    def _get_script_code(cls, script_path):
        # noinspection PyBroadException
        try:
            with open(script_path, 'r') as f:
                script_code = f.read()
            return script_code
        except Exception:
            pass
        return ''

    @classmethod
    def _get_script_info(cls, filepath, check_uncommitted=True, create_requirements=True, log=None):
        jupyter_filepath = cls._get_jupyter_notebook_filename()
        if jupyter_filepath:
            script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
        else:
            script_path = Path(os.path.normpath(filepath)).absolute()
            if not script_path.is_file():
                raise ScriptInfoError(
                    "Script file [{}] could not be found".format(filepath)
                )

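        # repository detection starts from the directory that contains the script;
        # the first matching plugin (git/hg detectors, in priority order) is used below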
        script_dir = script_path.parent

        def _log(msg, *args, **kwargs):
            if not log:
                return
            log.warning(
                "Failed auto-detecting task repository: {}".format(
                    msg.format(*args, **kwargs)
                )
            )

        plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
        repo_info = DetectionResult()
        if not plugin:
            # guard: log may be None
            if log:
                log.info("No repository found, storing script code instead")
        else:
            try:
                repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
            except Exception as ex:
                _log("no info for {} ({})", script_dir, ex)
            else:
                if repo_info.is_empty():
                    _log("no info for {}", script_dir)

        repo_root = repo_info.root or script_dir
        working_dir = cls._get_working_dir(repo_root)
        entry_point = cls._get_entry_point(repo_root, script_path)
        if check_uncommitted:
            diff = cls._get_script_code(script_path.as_posix()) \
                if not plugin or not repo_info.commit else repo_info.diff
        else:
            diff = ''
        # if this is not jupyter, get the requirements.txt
        requirements = ''
        # create requirements if backend supports requirements
        # if jupyter is present, requirements will be created in the background, when saving a snapshot
        if not jupyter_filepath and Session.check_min_api_version('2.2'):
            script_requirements = ScriptRequirements(Path(repo_root).as_posix())
            if create_requirements:
                requirements = script_requirements.get_requirements()
        else:
            script_requirements = None

        script_info = dict(
            repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
            branch=repo_info.branch,
            version_num=repo_info.commit,
            entry_point=entry_point,
            working_dir=working_dir,
            diff=diff,
            requirements={'pip': requirements} if requirements else None,
        )

        messages = []
        if repo_info.modified:
            messages.append(
                "======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY {} <======".format(
                    script_info.get("repository", "")
                )
            )

        if not any(script_info.values()):
            script_info = None

        return (ScriptInfoResult(script=script_info, warning_messages=messages),
                script_requirements)

    @classmethod
    def get(cls, filepath=sys.argv[0], check_uncommitted=True, create_requirements=True, log=None):
        try:
            return cls._get_script_info(
                filepath=filepath, check_uncommitted=check_uncommitted,
                create_requirements=create_requirements, log=log)
        except Exception as ex:
            if log:
                log.warning("Failed auto-detecting task repository: {}".format(ex))
        return ScriptInfoResult(), None


@attr.s
class ScriptInfoResult(object):
    script = attr.ib(default=None)
    warning_messages = attr.ib(factory=list)
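
# Illustrative usage (a minimal sketch; 'train.py' and the logger name are hypothetical):
#
#     import logging
#     result, script_requirements = ScriptInfo.get(filepath='train.py', log=logging.getLogger('example'))
#     if result.script:
#         print(result.script['repository'], result.script['entry_point'])
#     for msg in result.warning_messages:
#         print(msg)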