From e6e628517a865c5c1195ef5d661efc54437ac9e7 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sun, 21 Apr 2024 12:28:52 +0300
Subject: [PATCH] Black formatting

---
 .../backend_interface/task/repo/scriptinfo.py | 531 ++++++++++--------
 1 file changed, 307 insertions(+), 224 deletions(-)

diff --git a/clearml/backend_interface/task/repo/scriptinfo.py b/clearml/backend_interface/task/repo/scriptinfo.py
index 69bd3739..ca5e5713 100644
--- a/clearml/backend_interface/task/repo/scriptinfo.py
+++ b/clearml/backend_interface/task/repo/scriptinfo.py
@@ -26,9 +26,9 @@ class ScriptInfoError(Exception):
 
 
 class ScriptRequirements(object):
-    _detailed_import_report = deferred_config('development.detailed_import_report', False)
+    _detailed_import_report = deferred_config("development.detailed_import_report", False)
     _max_requirements_size = 512 * 1024
-    _packages_remove_version = ('setuptools', )
+    _packages_remove_version = ("setuptools",)
     _ignore_packages = set()
 
     @classmethod
@@ -38,19 +38,24 @@ class ScriptRequirements(object):
     def __init__(self, root_folder):
         self._root_folder = root_folder
 
-    def get_requirements(self, entry_point_filename=None, add_missing_installed_packages=False,
-                         detailed_req_report=None):
+    def get_requirements(
+        self, entry_point_filename=None, add_missing_installed_packages=False, detailed_req_report=None
+    ):
         # noinspection PyBroadException
         try:
             from ....utilities.pigar.reqs import get_installed_pkgs_detail
             from ....utilities.pigar.__main__ import GenerateReqs
-            installed_pkgs = self._remove_package_versions(
-                get_installed_pkgs_detail(), self._packages_remove_version)
-            gr = GenerateReqs(save_path='', project_path=self._root_folder, installed_pkgs=installed_pkgs,
-                              ignores=['.git', '.hg', '.idea', '__pycache__', '.ipynb_checkpoints',
-                                       'site-packages', 'dist-packages'])
+
+            installed_pkgs = self._remove_package_versions(get_installed_pkgs_detail(), self._packages_remove_version)
+            gr = GenerateReqs(
+                save_path="",
+                project_path=self._root_folder,
+                installed_pkgs=installed_pkgs,
+                ignores=[".git", ".hg", ".idea", "__pycache__", ".ipynb_checkpoints", "site-packages", "dist-packages"],
+            )
             reqs, try_imports, guess, local_pks = gr.extract_reqs(
-                module_callback=ScriptRequirements.add_trains_used_packages, entry_point_filename=entry_point_filename)
+                module_callback=ScriptRequirements.add_trains_used_packages, entry_point_filename=entry_point_filename
+            )
             if add_missing_installed_packages and guess:
                 for k in guess:
                     if k not in reqs:
@@ -58,7 +63,7 @@ class ScriptRequirements(object):
             return self.create_requirements_txt(reqs, local_pks, detailed=detailed_req_report)
         except Exception as ex:
             self._get_logger().warning("Failed auto-generating package requirements: {}".format(ex))
-            return '', ''
+            return "", ""
 
     @staticmethod
     def add_trains_used_packages(modules):
@@ -67,66 +72,72 @@ class ScriptRequirements(object):
         try:
             # noinspection PyPackageRequirements,PyUnresolvedReferences
             import boto3  # noqa: F401
-            modules.add('boto3', 'clearml.storage', 0)
+
+            modules.add("boto3", "clearml.storage", 0)
         except Exception:
             pass
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements,PyUnresolvedReferences
             from google.cloud import storage  # noqa: F401
-            modules.add('google_cloud_storage', 'clearml.storage', 0)
+
+            modules.add("google_cloud_storage", "clearml.storage", 0)
         except Exception:
             pass
         # noinspection PyBroadException
        try:
             # noinspection PyPackageRequirements,PyUnresolvedReferences
             from azure.storage.blob import ContentSettings  # noqa: F401
-            modules.add('azure_storage_blob', 'clearml.storage', 0)
+
+            modules.add("azure_storage_blob", "clearml.storage", 0)
         except Exception:
             pass
 
         # bugfix, replace sklearn with scikit-learn name
-        if 'sklearn' in modules:
-            sklearn = modules.pop('sklearn', {})
+        if "sklearn" in modules:
+            sklearn = modules.pop("sklearn", {})
             for fname, lines in sklearn.items():
-                modules.add('scikit_learn', fname, lines)
+                modules.add("scikit_learn", fname, lines)
 
         # bugfix, replace sklearn with scikit-learn name
-        if 'skimage' in modules:
-            skimage = modules.pop('skimage', {})
+        if "skimage" in modules:
+            skimage = modules.pop("skimage", {})
             for fname, lines in skimage.items():
-                modules.add('scikit_image', fname, lines)
+                modules.add("scikit_image", fname, lines)
 
-        if 'tensorflow-intel' in modules:
-            tfmodule = modules.pop('tensorflow-intel', {})
+        if "tensorflow-intel" in modules:
+            tfmodule = modules.pop("tensorflow-intel", {})
             for fname, lines in tfmodule.items():
-                modules.add('tensorflow', fname, lines)
+                modules.add("tensorflow", fname, lines)
 
         # if we have torch, and it supports tensorboard, we should add that as well
         # (because it will not be detected automatically)
-        if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
+        if "torch" in modules and "tensorboard" not in modules and "tensorboardX" not in modules:
             # noinspection PyBroadException
             try:
                 # see if this version of torch support tensorboard
                 # noinspection PyPackageRequirements,PyUnresolvedReferences
                 import torch.utils.tensorboard  # noqa: F401
+
                 # noinspection PyPackageRequirements,PyUnresolvedReferences
                 import tensorboard  # noqa: F401
-                modules.add('tensorboard', 'torch', 0)
+
+                modules.add("tensorboard", "torch", 0)
             except Exception:
                 pass
 
         # remove setuptools, we should not specify this module version. It is installed by default
-        if 'setuptools' in modules:
-            modules.pop('setuptools', {})
+        if "setuptools" in modules:
+            modules.pop("setuptools", {})
 
         # add forced requirements:
         # noinspection PyBroadException
         try:
             from ..task import Task
+
             # noinspection PyProtectedMember
             for package, version in Task._force_requirements.items():
-                modules.add(package, 'clearml', 0)
+                modules.add(package, "clearml", 0)
         except Exception:
             pass
 
@@ -140,42 +151,42 @@ class ScriptRequirements(object):
 
         # noinspection PyBroadException
         try:
-            conda_requirements = ''
-            conda_prefix = os.environ.get('CONDA_PREFIX')
+            conda_requirements = ""
+            conda_prefix = os.environ.get("CONDA_PREFIX")
             if conda_prefix and not conda_prefix.endswith(os.path.sep):
                 conda_prefix += os.path.sep
 
             if conda_prefix and sys.executable.startswith(conda_prefix):
-                conda_packages_json = get_command_output(['conda', 'list', '--json'])
+                conda_packages_json = get_command_output(["conda", "list", "--json"])
                 conda_packages_json = json.loads(conda_packages_json)
                 reqs_lower = {k.lower(): (k, v) for k, v in reqs.items()}
                 for r in conda_packages_json:
                     # the exception is cudatoolkit which we want to log anyhow
-                    if r.get('name') == 'cudatoolkit' and r.get('version'):
-                        conda_requirements += '{0} {1} {2}\n'.format(r.get('name'), '==', r.get('version'))
+                    if r.get("name") == "cudatoolkit" and r.get("version"):
+                        conda_requirements += "{0} {1} {2}\n".format(r.get("name"), "==", r.get("version"))
                         continue
                     # check if this is a pypi package, if it is, leave it outside
-                    if not r.get('channel') or r.get('channel') == 'pypi':
+                    if not r.get("channel") or r.get("channel") == "pypi":
                         continue
                     # check if we have it in our required packages
-                    name = r['name'].lower()
+                    name = r["name"].lower()
                     # hack support pytorch/torch different naming convention
-                    if name == 'pytorch':
-                        name = 'torch'
+                    if name == "pytorch":
+                        name = "torch"
                     k, v = None, None
                     if name in reqs_lower:
                         k, v = reqs_lower.get(name, (None, None))
                     else:
-                        name = name.replace('-', '_')
+                        name = name.replace("-", "_")
                         if name in reqs_lower:
                             k, v = reqs_lower.get(name, (None, None))
                     if k and v is not None:
                         if v.version:
-                            conda_requirements += '{0} {1} {2}\n'.format(k, '==', v.version)
+                            conda_requirements += "{0} {1} {2}\n".format(k, "==", v.version)
                         else:
-                            conda_requirements += '{0}\n'.format(k)
+                            conda_requirements += "{0}\n".format(k)
         except Exception:
-            conda_requirements = ''
+            conda_requirements = ""
 
         # add forced requirements:
         forced_packages = {}
@@ -183,6 +194,7 @@ class ScriptRequirements(object):
         # noinspection PyBroadException
         try:
             from ..task import Task
+
             # noinspection PyProtectedMember
             forced_packages = copy(Task._force_requirements)
             # noinspection PyProtectedMember
@@ -191,18 +203,18 @@ class ScriptRequirements(object):
             pass
 
         # python version header
-        requirements_txt = '# Python ' + sys.version.replace('\n', ' ').replace('\r', ' ') + '\n'
+        requirements_txt = "# Python " + sys.version.replace("\n", " ").replace("\r", " ") + "\n"
 
         if local_pks:
-            requirements_txt += '\n# Local modules found - skipping:\n'
+            requirements_txt += "\n# Local modules found - skipping:\n"
             for k, v in local_pks.sorted_items():
                 if v.version:
-                    requirements_txt += '# {0} == {1}\n'.format(k, v.version)
+                    requirements_txt += "# {0} == {1}\n".format(k, v.version)
                 else:
-                    requirements_txt += '# {0}\n'.format(k)
+                    requirements_txt += "# {0}\n".format(k)
 
         # requirement summary
-        requirements_txt += '\n'
+        requirements_txt += "\n"
         for k, v in reqs.sorted_items():
             if k in ignored_packages or k.lower() in ignored_packages:
                 continue
@@ -220,48 +232,50 @@ class ScriptRequirements(object):
 
         requirements_txt_packages_only = requirements_txt
         if detailed:
-            requirements_txt_packages_only = \
-                requirements_txt + '\n# Skipping detailed import analysis, it is too large\n'
+            requirements_txt_packages_only = (
+                requirements_txt + "\n# Skipping detailed import analysis, it is too large\n"
+            )
 
             # requirements details (in comments)
-            requirements_txt += '\n' + \
-                                '# Detailed import analysis\n' \
-                                '# **************************\n'
+            requirements_txt += "\n# Detailed import analysis\n# **************************\n"
 
             if local_pks:
                 for k, v in local_pks.sorted_items():
-                    requirements_txt += '\n'
-                    requirements_txt += '# IMPORT LOCAL PACKAGE {0}\n'.format(k)
-                    requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
+                    requirements_txt += "\n"
+                    requirements_txt += "# IMPORT LOCAL PACKAGE {0}\n".format(k)
+                    requirements_txt += "".join(["# {0}\n".format(c) for c in v.comments.sorted_items()])
 
             for k, v in reqs.sorted_items():
                 if not v:
                     continue
-                requirements_txt += '\n'
-                if k == '-e':
-                    requirements_txt += '# IMPORT PACKAGE {0} {1}\n'.format(k, v.version)
+                requirements_txt += "\n"
+                if k == "-e":
+                    requirements_txt += "# IMPORT PACKAGE {0} {1}\n".format(k, v.version)
                 else:
-                    requirements_txt += '# IMPORT PACKAGE {0}\n'.format(k)
-                requirements_txt += ''.join(['# {0}\n'.format(c) for c in v.comments.sorted_items()])
+                    requirements_txt += "# IMPORT PACKAGE {0}\n".format(k)
+                requirements_txt += "".join(["# {0}\n".format(c) for c in v.comments.sorted_items()])
 
         # make sure we do not exceed the size a size limit
-        return (requirements_txt if len(requirements_txt) < ScriptRequirements._max_requirements_size
-                else requirements_txt_packages_only,
-                conda_requirements)
+        return (
+            requirements_txt
+            if len(requirements_txt) < ScriptRequirements._max_requirements_size
+            else requirements_txt_packages_only,
+            conda_requirements,
+        )
 
     @staticmethod
     def _make_req_line(k, version):
-        requirements_txt = ''
-        if k == '-e' and version:
-            requirements_txt += '{0}\n'.format(version)
-        elif k.startswith('-e '):
-            requirements_txt += '{0} {1}\n'.format(k.replace('-e ', '', 1), version or '')
-        elif version and str(version or ' ').strip()[0].isdigit():
-            requirements_txt += '{0} {1} {2}\n'.format(k, '==', version)
+        requirements_txt = ""
+        if k == "-e" and version:
+            requirements_txt += "{0}\n".format(version)
+        elif k.startswith("-e "):
+            requirements_txt += "{0} {1}\n".format(k.replace("-e ", "", 1), version or "")
+        elif version and str(version or " ").strip()[0].isdigit():
+            requirements_txt += "{0} {1} {2}\n".format(k, "==", version)
         elif version and str(version).strip():
-            requirements_txt += '{0} {1}\n'.format(k, version)
+            requirements_txt += "{0} {1}\n".format(k, version)
         else:
-            requirements_txt += '{0}\n'.format(k)
+            requirements_txt += "{0}\n".format(k)
         return requirements_txt
 
     @staticmethod
@@ -269,7 +283,8 @@ class ScriptRequirements(object):
         def _internal(_installed_pkgs):
             return {
                 k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
-                if not isinstance(v, dict) else _internal(v)
+                if not isinstance(v, dict)
+                else _internal(v)
                 for k, v in _installed_pkgs.items()
             }
 
@@ -280,10 +295,10 @@ class _JupyterObserver(object):
     _thread = None
     _exit_event = None
     _sync_event = None
-    _sample_frequency = 30.
-    _first_sample_frequency = 3.
+    _sample_frequency = 30.0
+    _first_sample_frequency = 3.0
     _jupyter_history_logger = None
-    _store_notebook_artifact = deferred_config('development.store_jupyter_notebook_artifact', True)
+    _store_notebook_artifact = deferred_config("development.store_jupyter_notebook_artifact", True)
 
     @classmethod
     def _get_logger(cls):
@@ -337,6 +352,7 @@ class _JupyterObserver(object):
         try:
             # noinspection PyPackageRequirements
             from nbconvert.exporters import PythonExporter  # noqa
+
             _script_exporter = PythonExporter()
         except Exception:
             _script_exporter = None
@@ -344,10 +360,11 @@ class _JupyterObserver(object):
             if _script_exporter is None:
                 # noinspection PyPackageRequirements
                 from nbconvert.exporters.script import ScriptExporter  # noqa
+
                 _script_exporter = ScriptExporter()
 
         except Exception as ex:
-            cls._get_logger().warning('Could not read Jupyter Notebook: {}'.format(ex))
+            cls._get_logger().warning("Could not read Jupyter Notebook: {}".format(ex))
             if isinstance(ex, ImportError):
                 module_name = getattr(ex, "name", None)
                 if module_name:
@@ -362,6 +379,7 @@ class _JupyterObserver(object):
             from ....utilities.pigar.reqs import get_installed_pkgs_detail, file_import_modules
             from ....utilities.pigar.modules import ReqsModules
             from ....utilities.pigar.log import logger
+
             logger.setLevel(logging.WARNING)
         except Exception:
             file_import_modules = None
@@ -380,7 +398,7 @@ class _JupyterObserver(object):
             local_jupyter_filename = jupyter_notebook_filename
         else:
             notebook = None
-            folder = mkdtemp(suffix='.notebook')
+            folder = mkdtemp(suffix=".notebook")
             if notebook_name.endswith(".py"):
                 notebook_name = notebook_name.replace(".py", ".ipynb")
             if not notebook_name.endswith(".ipynb"):
@@ -395,21 +413,23 @@ class _JupyterObserver(object):
         # noinspection PyBroadException
         try:
             from ....version import __version__
-            our_module = cls.__module__.split('.')[0], __version__
+
+            our_module = cls.__module__.split(".")[0], __version__
         except Exception:
             our_module = None
 
         # noinspection PyBroadException
         try:
             import re
-            replace_ipython_pattern = re.compile(r'\n([ \t]*)get_ipython\([ \t]*\)')
-            replace_ipython_display_pattern = re.compile(r'\n([ \t]*)display\(')
+
+            replace_ipython_pattern = re.compile(r"\n([ \t]*)get_ipython\([ \t]*\)")
+            replace_ipython_display_pattern = re.compile(r"\n([ \t]*)display\(")
         except Exception:
             replace_ipython_pattern = None
             replace_ipython_display_pattern = None
 
         # main observer loop, check if we need to exit
-        while not cls._exit_event.wait(timeout=0.):
+        while not cls._exit_event.wait(timeout=0.0):
             # wait for timeout or sync event
             cls._sync_event.wait(cls._sample_frequency if counter else cls._first_sample_frequency)
 
@@ -459,21 +479,22 @@ class _JupyterObserver(object):
                             os.unlink(local_jupyter_filename)
                         except Exception:
                             pass
-                        get_ipython().run_line_magic('history', '-t -f {}'.format(local_jupyter_filename))
-                        with open(local_jupyter_filename, 'r') as f:
+                        get_ipython().run_line_magic("history", "-t -f {}".format(local_jupyter_filename))
+                        with open(local_jupyter_filename, "r") as f:
                             script_code = f.read()
                         # load the modules
                         from ....utilities.pigar.modules import ImportedModules
+
                         fmodules = ImportedModules()
-                        for nm in set([str(m).split('.')[0] for m in sys.modules]):
-                            fmodules.add(nm, 'notebook', 0)
+                        for nm in set([str(m).split(".")[0] for m in sys.modules]):
+                            fmodules.add(nm, "notebook", 0)
                     except Exception:
                         continue
 
                 if _script_exporter is None:
-                    current_script_hash = 'error_notebook_not_found.py'
-                    requirements_txt = ''
-                    conda_requirements = ''
+                    current_script_hash = "error_notebook_not_found.py"
+                    requirements_txt = ""
+                    conda_requirements = ""
                 else:
                     # get notebook python script
                     if script_code is None and local_jupyter_filename:
@@ -481,44 +502,48 @@ class _JupyterObserver(object):
                         if cls._store_notebook_artifact:
                             # also upload the jupyter notebook as artifact
                             task.upload_artifact(
-                                name='notebook',
+                                name="notebook",
                                 artifact_object=Path(local_jupyter_filename),
-                                preview='See `notebook preview` artifact',
-                                metadata={'UPDATE': datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')},
+                                preview="See `notebook preview` artifact",
+                                metadata={"UPDATE": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")},
                                 wait_on_upload=True,
                             )
                             # noinspection PyBroadException
                             try:
                                 from nbconvert.exporters import HTMLExporter  # noqa
+
                                 html, _ = HTMLExporter().from_filename(filename=local_jupyter_filename)
-                                local_html = Path(gettempdir()) / 'notebook_{}.html'.format(task.id)
-                                with open(local_html.as_posix(), 'wt', encoding="utf-8") as f:
+                                local_html = Path(gettempdir()) / "notebook_{}.html".format(task.id)
+                                with open(local_html.as_posix(), "wt", encoding="utf-8") as f:
                                     f.write(html)
                                 task.upload_artifact(
-                                    name='notebook preview', artifact_object=local_html,
-                                    preview='Click `FILE PATH` link',
-                                    metadata={'UPDATE': datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')},
+                                    name="notebook preview",
+                                    artifact_object=local_html,
+                                    preview="Click `FILE PATH` link",
+                                    metadata={"UPDATE": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")},
                                     delete_after_upload=True,
                                     wait_on_upload=True,
                                 )
                             except Exception:
                                 pass
 
-                    current_script_hash = hash(script_code + (current_cell or ''))
+                    current_script_hash = hash(script_code + (current_cell or ""))
                     if prev_script_hash and prev_script_hash == current_script_hash:
                         continue
-                    requirements_txt = ''
-                    conda_requirements = ''
+                    requirements_txt = ""
+                    conda_requirements = ""
                     # parse jupyter python script and prepare pip requirements (pigar)
                     # if backend supports requirements
-                    if file_import_modules and Session.check_min_api_version('2.2'):
+                    if file_import_modules and Session.check_min_api_version("2.2"):
                         if fmodules is None:
                             fmodules, _ = file_import_modules(
-                                notebook.parts[-1] if notebook else 'notebook', script_code)
+                                notebook.parts[-1] if notebook else "notebook", script_code
+                            )
                         if current_cell:
                             cell_fmodules, _ = file_import_modules(
-                                notebook.parts[-1] if notebook else 'notebook', current_cell)
+                                notebook.parts[-1] if notebook else "notebook", current_cell
+                            )
                             # noinspection PyBroadException
                             try:
                                 fmodules |= cell_fmodules
@@ -526,7 +551,7 @@ class _JupyterObserver(object):
                                 pass
                         # add current cell to the script
                         if current_cell:
-                            script_code += '\n' + current_cell
+                            script_code += "\n" + current_cell
                         fmodules = ScriptRequirements.add_trains_used_packages(fmodules)
                         # noinspection PyUnboundLocalVariable
                         installed_pkgs = get_installed_pkgs_detail()
@@ -546,15 +571,15 @@ class _JupyterObserver(object):
                 # we will not be able to run them anyhow
                 # probably should be better dealt with, because multi line will break it
                 if replace_ipython_pattern:
-                    script_code = replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', script_code)
+                    script_code = replace_ipython_pattern.sub(r"\n# \g<1>get_ipython()", script_code)
                 if replace_ipython_display_pattern:
-                    script_code = replace_ipython_display_pattern.sub(r'\n\g<1>print(', script_code)
+                    script_code = replace_ipython_display_pattern.sub(r"\n\g<1>print(", script_code)
 
                 # update script
                 prev_script_hash = current_script_hash
                 data_script = task.data.script
                 data_script.diff = script_code
-                data_script.requirements = {'pip': requirements_txt, 'conda': conda_requirements}
+                data_script.requirements = {"pip": requirements_txt, "conda": conda_requirements}
                 # noinspection PyProtectedMember
                 task._update_script(script=data_script)
                 # update requirements
@@ -580,33 +605,36 @@ class ScriptInfo(object):
    def _jupyter_install_post_store_hook(cls, jupyter_notebook_filename, notebook_name=None, log_history=False):
         # noinspection PyBroadException
         try:
-            if 'IPython' in sys.modules:
+            if "IPython" in sys.modules:
                 # noinspection PyPackageRequirements
                 from IPython import get_ipython
+
                 if get_ipython():
                     _JupyterObserver.observer(
-                        jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history)
-                    get_ipython().events.register('pre_run_cell', _JupyterObserver.signal_sync)
+                        jupyter_notebook_filename, notebook_name=notebook_name, log_history=log_history
+                    )
+                    get_ipython().events.register("pre_run_cell", _JupyterObserver.signal_sync)
                     if log_history:
-                        get_ipython().events.register('post_run_cell', _JupyterObserver.signal_sync)
+                        get_ipython().events.register("post_run_cell", _JupyterObserver.signal_sync)
         except Exception:
             pass
 
     @classmethod
     def _get_jupyter_notebook_filename(cls):
         # check if we are running in vscode, we have the jupyter notebook defined:
-        if 'IPython' in sys.modules:
+        if "IPython" in sys.modules:
             # noinspection PyBroadException
             try:
                 from IPython import get_ipython  # noqa
+
                 ip = get_ipython()
                 # vscode-jupyter PR #8531 added this variable
-                local_ipynb_file = ip.__dict__.get('user_ns', {}).get('__vsc_ipynb_file__') if ip else None
+                local_ipynb_file = ip.__dict__.get("user_ns", {}).get("__vsc_ipynb_file__") if ip else None
                 if local_ipynb_file:
                     # now replace the .ipynb with .py
                     # we assume we will have that file available for monitoring
                     local_ipynb_file = Path(local_ipynb_file)
-                    script_entry_point = local_ipynb_file.with_suffix('.py').as_posix()
+                    script_entry_point = local_ipynb_file.with_suffix(".py").as_posix()
 
                     # install the post store hook,
                     # notice that if we do not have a local file we serialize/write every time the entire notebook
@@ -635,6 +663,7 @@ class ScriptInfo(object):
             try:
                 # noinspection PyPackageRequirements
                 from notebook.notebookapp import list_running_servers  # noqa <= Notebook v6
+
                 # noinspection PyBroadException
                 try:
                     jupyter_servers += list(list_running_servers())
@@ -650,6 +679,7 @@ class ScriptInfo(object):
            try:
                 # noinspection PyPackageRequirements
                 from jupyter_server.serverapp import list_running_servers  # noqa
+
                 # noinspection PyBroadException
                 try:
                     jupyter_servers += list(list_running_servers())
@@ -660,7 +690,7 @@ class ScriptInfo(object):
             except Exception:
                 pass
 
-        current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
+        current_kernel = sys.argv[2].split(os.path.sep)[-1].replace("kernel-", "").replace(".json", "")
         notebook_path = None
         notebook_name = None
 
@@ -669,40 +699,56 @@ class ScriptInfo(object):
             cookies = None
             password = None
-            if server_info and server_info.get('password'):
+            if server_info and server_info.get("password"):
                 # we need to get the password
                 from ....config import config
-                password = config.get('development.jupyter_server_password', '')
+
+                password = config.get("development.jupyter_server_password", "")
                 if not password:
                     cls._get_logger().warning(
-                        'Password protected Jupyter Notebook server was found! '
-                        'Add `sdk.development.jupyter_server_password=` to ~/clearml.conf')
-                    return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
+                        "Password protected Jupyter Notebook server was found! "
+                        "Add `sdk.development.jupyter_server_password=` to ~/clearml.conf"
+                    )
+                    return os.path.join(os.getcwd(), "error_notebook_not_found.py")
 
-                r = requests.get(url=server_info['url'] + 'login')
-                cookies = {'_xsrf': r.cookies.get('_xsrf', '')}
-                r = requests.post(server_info['url'] + 'login?next', cookies=cookies,
-                                  data={'_xsrf': cookies['_xsrf'], 'password': password})
+                r = requests.get(url=server_info["url"] + "login")
+                cookies = {"_xsrf": r.cookies.get("_xsrf", "")}
+                r = requests.post(
+                    server_info["url"] + "login?next",
+                    cookies=cookies,
+                    data={"_xsrf": cookies["_xsrf"], "password": password},
+                )
                 cookies.update(r.cookies)
 
             # get api token from ENV - if not defined then from server info
-            auth_token = os.getenv('JUPYTERHUB_API_TOKEN') or server_info.get('token') or ''
+            auth_token = os.getenv("JUPYTERHUB_API_TOKEN") or server_info.get("token") or ""
             try:
                 r = requests.get(
-                    url=server_info['url'] + 'api/sessions', cookies=cookies,
-                    headers={'Authorization': 'token {}'.format(auth_token), })
+                    url=server_info["url"] + "api/sessions",
+                    cookies=cookies,
+                    headers={
+                        "Authorization": "token {}".format(auth_token),
+                    },
+                )
             except requests.exceptions.SSLError:
                 # disable SSL check warning
                 from urllib3.exceptions import InsecureRequestWarning
+
                 # noinspection PyUnresolvedReferences
                 requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
                 # fire request
                 r = requests.get(
-                    url=server_info['url'] + 'api/sessions', cookies=cookies,
-                    headers={'Authorization': 'token {}'.format(auth_token), }, verify=False)
+                    url=server_info["url"] + "api/sessions",
+                    cookies=cookies,
+                    headers={
+                        "Authorization": "token {}".format(auth_token),
+                    },
+                    verify=False,
+                )
                 # enable SSL check warning
                 import warnings
-                warnings.simplefilter('default', InsecureRequestWarning)
+
+                warnings.simplefilter("default", InsecureRequestWarning)
 
             # send request to the jupyter server
             try:
@@ -710,14 +756,17 @@ class ScriptInfo(object):
             except Exception as ex:
                 # raise on last one only
                 if server_index == len(jupyter_servers) - 1:
-                    cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format(
-                        ' [password={}]'.format(password) if server_info.get('password') else '', ex))
-                    return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
+                    cls._get_logger().warning(
+                        "Failed accessing the jupyter server{}: {}".format(
+                            " [password={}]".format(password) if server_info.get("password") else "", ex
+                        )
+                    )
+                    return os.path.join(os.getcwd(), "error_notebook_not_found.py")
 
             notebooks = r.json()
             cur_notebook = None
             for n in notebooks:
-                if n['kernel']['id'] == current_kernel:
+                if n["kernel"]["id"] == current_kernel:
                     cur_notebook = n
                     break
 
@@ -725,8 +774,8 @@ class ScriptInfo(object):
             if not cur_notebook:
                 continue
 
-            notebook_path = cur_notebook['notebook'].get('path', '')
-            notebook_name = cur_notebook['notebook'].get('name', '')
+            notebook_path = cur_notebook["notebook"].get("path", "")
+            notebook_name = cur_notebook["notebook"].get("name", "")
 
            if notebook_path:
                 break
@@ -746,15 +795,16 @@ class ScriptInfo(object):
                 notebook_name = colab_name
                 log_history = False
 
-            script_entry_point = str(notebook_name or 'notebook').replace(
-                '>', '_').replace('<', '_').replace('.ipynb', '.py')
+            script_entry_point = (
+                str(notebook_name or "notebook").replace(">", "_").replace("<", "_").replace(".ipynb", ".py")
+            )
-            if not script_entry_point.lower().endswith('.py'):
-                script_entry_point += '.py'
+            if not script_entry_point.lower().endswith(".py"):
+                script_entry_point += ".py"
 
             local_ipynb_file = None
         elif notebook_path is not None:
             # always slash, because this is from uri (so never backslash not even on windows)
-            entry_point_filename = notebook_path.split('/')[-1]
+            entry_point_filename = notebook_path.split("/")[-1]
 
             # now we should try to find the actual file
             entry_point = (Path.cwd() / entry_point_filename).absolute()
@@ -765,7 +815,7 @@ class ScriptInfo(object):
             if not entry_point.exists():
                 # noinspection PyBroadException
                 try:
-                    alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
+                    alternative_entry_point = "-".join(entry_point_filename.split("-")[:-5]) + ".ipynb"
                     # now we should try to find the actual file
                     entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
                     if not entry_point_alternative.is_file():
@@ -775,25 +825,26 @@ class ScriptInfo(object):
                     if entry_point_alternative.exists():
                         entry_point = entry_point_alternative
                 except Exception as ex:
-                    cls._get_logger().warning('Failed accessing jupyter notebook {}: {}'.format(notebook_path, ex))
+                    cls._get_logger().warning("Failed accessing jupyter notebook {}: {}".format(notebook_path, ex))
 
             # if we failed to get something we can access, print an error
             if not entry_point.exists():
                 cls._get_logger().warning(
-                    'Jupyter Notebook auto-logging failed, could not access: {}'.format(entry_point))
-                return 'error_notebook_not_found.py'
+                    "Jupyter Notebook auto-logging failed, could not access: {}".format(entry_point)
+                )
+                return "error_notebook_not_found.py"
 
             # get local ipynb for observer
             local_ipynb_file = entry_point.as_posix()
 
             # now replace the .ipynb with .py
             # we assume we will have that file available with the Jupyter notebook plugin
-            entry_point = entry_point.with_suffix('.py')
+            entry_point = entry_point.with_suffix(".py")
             script_entry_point = entry_point.as_posix()
         else:
             # we could not find and access any jupyter server
-            cls._get_logger().warning('Failed accessing the jupyter server(s): {}'.format(jupyter_servers))
+            cls._get_logger().warning("Failed accessing the jupyter server(s): {}".format(jupyter_servers))
             return None  # 'error_notebook_not_found.py'
 
         # install the post store hook,
@@ -819,8 +870,7 @@ class ScriptInfo(object):
             notebook_data = json.load(f)
         client = boto3.client("sagemaker")
         response = client.create_presigned_domain_url(
-            DomainId=notebook_data["DomainId"],
-            UserProfileName=notebook_data["UserProfileName"]
+            DomainId=notebook_data["DomainId"], UserProfileName=notebook_data["UserProfileName"]
         )
         authorized_url = response["AuthorizedUrl"]
         authorized_url_parsed = urlparse(authorized_url)
@@ -842,7 +892,7 @@ class ScriptInfo(object):
         try:
             from google.colab import _message  # noqa
 
-            notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
+            notebook = _message.blocking_request("get_ipynb", timeout_sec=timeout)["ipynb"]
             notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
             if not notebook_name.endswith(".ipynb"):
                 notebook_name += ".ipynb"
@@ -863,8 +913,8 @@ class ScriptInfo(object):
         try:
             # Use os.path.relpath as it calculates up dir movements (../)
             entry_point = os.path.relpath(
-                str(os.path.realpath(script_path.as_posix())),
-                str(cls._get_working_dir(repo_root, return_abs=True)))
+                str(os.path.realpath(script_path.as_posix())), str(cls._get_working_dir(repo_root, return_abs=True))
+            )
         except ValueError:
             # Working directory not under repository root
             entry_point = script_path.relative_to(repo_root)
@@ -875,11 +925,12 @@ class ScriptInfo(object):
     def _cwd(cls):
         # return the current working directory (solve for hydra changing it)
         # check if running with hydra
-        if sys.modules.get('hydra'):
+        if sys.modules.get("hydra"):
             # noinspection PyBroadException
             try:
                 # noinspection PyPackageRequirements
                 import hydra  # noqa
+
                 return Path(hydra.utils.get_original_cwd()).absolute()
             except Exception:
                 pass
@@ -900,7 +951,7 @@ class ScriptInfo(object):
             return cwd.as_posix() if return_abs else relative
         except ValueError:
             # Working directory not under repository root, default to repo root
-            return repo_root.as_posix() if return_abs else '.'
+            return repo_root.as_posix() if return_abs else "."
 
     @classmethod
     def _absolute_path(cls, file_path, cwd):
@@ -924,18 +975,26 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
-            with open(script_path, 'r', encoding='utf-8') as f:
+            with open(script_path, "r", encoding="utf-8") as f:
                 script_code = f.read()
             return script_code
         except Exception:
             pass
 
-        return ''
+        return ""
 
     @classmethod
     def _get_script_info(
-            cls, filepaths, check_uncommitted=True, create_requirements=True, log=None,
-            uncommitted_from_remote=False, detect_jupyter_notebook=True,
-            add_missing_installed_packages=False, detailed_req_report=None, force_single_script=False):
+        cls,
+        filepaths,
+        check_uncommitted=True,
+        create_requirements=True,
+        log=None,
+        uncommitted_from_remote=False,
+        detect_jupyter_notebook=True,
+        add_missing_installed_packages=False,
+        detailed_req_report=None,
+        force_single_script=False,
+    ):
         jupyter_filepath = cls._get_jupyter_notebook_filename() if detect_jupyter_notebook else None
         if jupyter_filepath:
             scripts_path = [Path(os.path.normpath(jupyter_filepath)).absolute()]
@@ -944,24 +1003,18 @@ class ScriptInfo(object):
             scripts_path = [Path(cls._absolute_path(os.path.normpath(f), cwd)) for f in filepaths if f]
             scripts_path = [f for f in scripts_path if f.exists()]
             if not scripts_path:
-                for f in (filepaths or []):
+                for f in filepaths or []:
                     if f and f.endswith("/"):
                         raise ScriptInfoError("python console detected")
-                raise ScriptInfoError(
-                    "Script file {} could not be found".format(filepaths)
-                )
+                raise ScriptInfoError("Script file {} could not be found".format(filepaths))
 
         scripts_dir = [f.parent for f in scripts_path]
 
         def _log(msg, *args, **kwargs):
             if not log:
                 return
-            log.warning(
-                "Failed auto-detecting task repository: {}".format(
-                    msg.format(*args, **kwargs)
-                )
-            )
+            log.warning("Failed auto-detecting task repository: {}".format(msg.format(*args, **kwargs)))
 
         script_dir = scripts_dir[0]
         script_path = scripts_path[0]
@@ -987,7 +1040,8 @@ class ScriptInfo(object):
         if plugin:
             try:
                 repo_info = plugin.get_info(
-                    str(script_dir), include_diff=check_uncommitted, diff_from_remote=uncommitted_from_remote)
+                    str(script_dir), include_diff=check_uncommitted, diff_from_remote=uncommitted_from_remote
+                )
             except SystemExit:
                 raise
             except Exception as ex:
@@ -998,7 +1052,7 @@ class ScriptInfo(object):
 
         repo_root = repo_info.root or script_dir
         if not plugin:
-            working_dir = VCS_WORK_DIR.get() or '.'
+            working_dir = VCS_WORK_DIR.get() or "."
             entry_point = VCS_ENTRY_POINT.get() or str(script_path.name)
         else:
             # allow to override the VCS working directory (notice relative to the git repo)
@@ -1016,8 +1070,11 @@ class ScriptInfo(object):
         if jupyter_filepath:
             diff = cls._get_script_code(script_path.as_posix())
         else:
-            diff = cls._get_script_code(script_path.as_posix()) \
-                if not plugin or not repo_info.commit else repo_info.diff
+            diff = (
+                cls._get_script_code(script_path.as_posix())
+                if not plugin or not repo_info.commit
+                else repo_info.diff
+            )
 
         if VCS_DIFF.exists():
             diff = VCS_DIFF.get() or ""
@@ -1026,27 +1083,32 @@ class ScriptInfo(object):
             if len(diff) > cls.max_diff_size_bytes:
                 messages.append(
                     "======> WARNING! Git diff too large to store "
-                    "({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
+                    "({}kb), skipping uncommitted changes <======".format(len(diff) // 1024)
+                )
                 auxiliary_git_diff = diff
-                diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
-                       '# full git diff available in Artifacts/auxiliary_git_diff\n' \
-                       '# Clear the section before enqueueing Task!\n'
+                diff = (
+                    "# WARNING! git diff too large to store, clear this section to execute without it.\n"
+                    "# full git diff available in Artifacts/auxiliary_git_diff\n"
+                    "# Clear the section before enqueueing Task!\n"
+                )
        else:
-            diff = ''
+            diff = ""
         # if this is not jupyter, get the requirements.txt
-        requirements = ''
-        conda_requirements = ''
+        requirements = ""
+        conda_requirements = ""
         # create requirements if backend supports requirements
         # if jupyter is present, requirements will be created in the background, when saving a snapshot
-        if not jupyter_filepath and Session.check_min_api_version('2.2'):
+        if not jupyter_filepath and Session.check_min_api_version("2.2"):
             script_requirements = ScriptRequirements(
-                Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
+                Path(repo_root).as_posix() if repo_info.url else script_path.as_posix()
+            )
             if create_requirements:
                 requirements, conda_requirements = script_requirements.get_requirements(
                     entry_point_filename=script_path.as_posix()
-                    if not repo_info.url and script_path.is_file() else None,
+                    if not repo_info.url and script_path.is_file()
+                    else None,
                     add_missing_installed_packages=add_missing_installed_packages,
                     detailed_req_report=detailed_req_report,
                 )
@@ -1062,8 +1124,8 @@ class ScriptInfo(object):
             working_dir=working_dir,
             diff=diff,
             ide=ide,
-            requirements={'pip': requirements, 'conda': conda_requirements} if requirements else None,
-            binary='python{}.{}'.format(sys.version_info.major, sys.version_info.minor),
+            requirements={"pip": requirements, "conda": conda_requirements} if requirements else None,
+            binary="python{}.{}".format(sys.version_info.major, sys.version_info.minor),
             repo_root=repo_root,
             jupyter_filepath=jupyter_filepath,
         )
@@ -1078,8 +1140,10 @@ class ScriptInfo(object):
         if not any(script_info.values()):
             script_info = None
 
-        return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
-                script_requirements)
+        return (
+            ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
+            script_requirements,
+        )
 
     @classmethod
     def _detect_distributed_execution(cls, entry_point, log):
@@ -1094,6 +1158,7 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             from psutil import Process  # noqa
+
             cmdline = Process().parent().cmdline()
             # first find the torch model call "torch.distributed.run" or "torch.distributed.launch"
             if is_torch_distributed:
@@ -1118,12 +1183,16 @@ class ScriptInfo(object):
                 log.info(
                     "{} execution detected: adjusting entrypoint to "
                     "reflect distributed execution arguments".format(
-                        "Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
+                        "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"
+                    )
                 )
         except Exception:
            if log:
-                log.warning("{} execution detected: Failed Detecting launch arguments, skipping".format(
-                    "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
+                log.warning(
+                    "{} execution detected: Failed Detecting launch arguments, skipping".format(
+                        "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"
+                    )
+                )
 
         return entry_point
 
@@ -1137,10 +1206,11 @@ class ScriptInfo(object):
             import ipykernel
             from glob import glob
             import json
+
             for f in glob(os.path.join(os.path.dirname(ipykernel.get_connection_file()), "??server-*.json")):
                 # noinspection PyBroadException
                 try:
-                    with open(f, 'r') as json_data:
+                    with open(f, "r") as json_data:
                         server_info = json.load(json_data)
                 except Exception:
                     continue
@@ -1151,16 +1221,28 @@ class ScriptInfo(object):
         return None
 
     @classmethod
-    def get(cls, filepaths=None, check_uncommitted=True, create_requirements=True, log=None,
-            uncommitted_from_remote=False, detect_jupyter_notebook=True, add_missing_installed_packages=False,
-            detailed_req_report=None, force_single_script=False):
+    def get(
+        cls,
+        filepaths=None,
+        check_uncommitted=True,
+        create_requirements=True,
+        log=None,
+        uncommitted_from_remote=False,
+        detect_jupyter_notebook=True,
+        add_missing_installed_packages=False,
+        detailed_req_report=None,
+        force_single_script=False
+    ):
         try:
             if not filepaths:
-                filepaths = [sys.argv[0], ]
+                filepaths = [
+                    sys.argv[0],
+                ]
             return cls._get_script_info(
                 filepaths=filepaths,
                 check_uncommitted=check_uncommitted,
-                create_requirements=create_requirements, log=log,
+                create_requirements=create_requirements,
+                log=log,
                 uncommitted_from_remote=uncommitted_from_remote,
                 detect_jupyter_notebook=detect_jupyter_notebook,
                 add_missing_installed_packages=add_missing_installed_packages,
@@ -1178,7 +1260,7 @@ class ScriptInfo(object):
     def is_running_from_module(cls):
         # noinspection PyBroadException
         try:
-            return '__main__' in sys.modules and vars(sys.modules['__main__'])['__package__']
+            return "__main__" in sys.modules and vars(sys.modules["__main__"])["__package__"]
         except Exception:
             return False
 
@@ -1189,30 +1271,29 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             # If this is jupyter, do not try to detect the running module, we know what we have.
-            if script_dict.get('jupyter_filepath'):
+            if script_dict.get("jupyter_filepath"):
                 return script_dict
 
             if cls.is_running_from_module():
-                argvs = ''
-                git_root = os.path.abspath(str(script_dict['repo_root'])) if script_dict['repo_root'] else None
+                argvs = ""
+                git_root = os.path.abspath(str(script_dict["repo_root"])) if script_dict["repo_root"] else None
                 for a in sys.argv[1:]:
                     if git_root and os.path.exists(a):
                         # check if common to project:
                         a_abs = os.path.abspath(a)
                         if os.path.commonpath([a_abs, git_root]) == git_root:
                             # adjust path relative to working dir inside git repo
-                            a = ' ' + os.path.relpath(
-                                a_abs, os.path.join(git_root, str(script_dict['working_dir'])))
-                    argvs += ' {}'.format(a)
+                            a = " " + os.path.relpath(a_abs, os.path.join(git_root, str(script_dict["working_dir"])))
+                    argvs += " {}".format(a)
 
                 # noinspection PyBroadException
                 try:
-                    module_name = vars(sys.modules['__main__'])['__spec__'].name
+                    module_name = vars(sys.modules["__main__"])["__spec__"].name
                 except Exception:
-                    module_name = vars(sys.modules['__main__'])['__package__']
 
+                    module_name = vars(sys.modules["__main__"])["__package__"]
                 # update the script entry point to match the real argv and module call
-                script_dict['entry_point'] = '-m {}{}'.format(module_name, (' ' + argvs) if argvs else '')
+                script_dict["entry_point"] = "-m {}{}".format(module_name, (" " + argvs) if argvs else "")
         except Exception:
             pass
         return script_dict
@@ -1220,13 +1301,14 @@ class ScriptInfo(object):
     @staticmethod
     def is_google_colab():
         # type: () -> bool
-        """ Know if the script is running from Google Colab """
+        """Know if the script is running from Google Colab"""
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
             from IPython import get_ipython
-            if get_ipython() and 'google.colab' in get_ipython().extension_manager.loaded:
+
+            if get_ipython() and "google.colab" in get_ipython().extension_manager.loaded:
                 return True
         except Exception:
             pass
@@ -1235,7 +1317,7 @@ class ScriptInfo(object):
     @staticmethod
     def is_vscode():
         # type: () -> bool
-        """ Know if the script is running from VSCode """
+        """Know if the script is running from VSCode"""
 
         if os.environ.get("TERM_PROGRAM") == "vscode":
             return True
@@ -1247,7 +1329,7 @@ class ScriptInfo(object):
     @staticmethod
     def is_pycharm():
         # type: () -> bool
-        """ Know if the script is running from PyCharm """
+        """Know if the script is running from PyCharm"""
 
         # youtrack.jetbrains.com ISSUE #PY-4853 added this variables
         if os.environ.get("PYCHARM_HOSTED"):
@@ -1259,7 +1341,7 @@ class ScriptInfo(object):
     @staticmethod
     def is_jupyter():
         # type: () -> bool
-        """ Know if the script is running from Jupyter """
+        """Know if the script is running from Jupyter"""
 
        if isinstance(ScriptInfo._get_jupyter_notebook_filename(), str):
             return True
@@ -1301,9 +1383,9 @@ class ScriptInfoResult(object):
 
 
 class _JupyterHistoryLogger(object):
-    _reg_replace_ipython = r'\n([ \t]*)get_ipython\(\)'
-    _reg_replace_magic = r'\n([ \t]*)%'
-    _reg_replace_bang = r'\n([ \t]*)!'
+    _reg_replace_ipython = r"\n([ \t]*)get_ipython\(\)"
+    _reg_replace_magic = r"\n([ \t]*)%"
+    _reg_replace_bang = r"\n([ \t]*)!"
 
     def __init__(self):
         self._exception_raised = False
@@ -1314,6 +1396,7 @@ class _JupyterHistoryLogger(object):
         # noinspection PyBroadException
         try:
            import re
+
             self._replace_ipython_pattern = re.compile(self._reg_replace_ipython)
             self._replace_magic_pattern = re.compile(self._reg_replace_magic)
             self._replace_bang_pattern = re.compile(self._reg_replace_bang)
@@ -1337,7 +1420,7 @@ class _JupyterHistoryLogger(object):
         # noinspection PyBroadException
         try:
             # if this is colab, the callbacks do not contain the raw_cell content, so we have to patch it
-            if 'google.colab' in self._ip.extension_manager.loaded:
+            if "google.colab" in self._ip.extension_manager.loaded:
                 self._ip._org_run_cell = self._ip.run_cell
                 self._ip.run_cell = partial(self._patched_run_cell, self._ip)
         except Exception:
@@ -1345,14 +1428,14 @@ class _JupyterHistoryLogger(object):
 
         # start with the current history
         self._initialize_history()
-        self._ip.events.register('post_run_cell', self._post_cell_callback)
-        self._ip.events.register('pre_run_cell', self._pre_cell_callback)
+        self._ip.events.register("post_run_cell", self._post_cell_callback)
+        self._ip.events.register("pre_run_cell", self._pre_cell_callback)
         self._ip.set_custom_exc((Exception,), self._exception_callback)
 
     def _patched_run_cell(self, shell, *args, **kwargs):
         # noinspection PyBroadException
         try:
-            raw_cell = kwargs.get('raw_cell') or args[0]
+            raw_cell = kwargs.get("raw_cell") or args[0]
             self._current_cell = raw_cell
         except Exception:
             pass
@@ -1360,13 +1443,13 @@ class _JupyterHistoryLogger(object):
         return shell._org_run_cell(*args, **kwargs)
 
     def history(self, filename):
-        with open(filename, 'wt') as f:
+        with open(filename, "wt") as f:
             for k, v in sorted(self._cells_code.items(), key=lambda p: p[0]):
                 f.write(v)
 
     def history_to_str(self):
         # return a pair: (history as str, current cell if we are in still in cell execution otherwise None)
-        return '\n'.join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell
+        return "\n".join(v for k, v in sorted(self._cells_code.items(), key=lambda p: p[0])), self._current_cell
 
     # noinspection PyUnusedLocal
     def _exception_callback(self, shell, etype, value, tb, tb_offset=None):
@@ -1397,7 +1480,7 @@ class _JupyterHistoryLogger(object):
         # add the cell history
         # noinspection PyBroadException
         try:
-            cell_code = '\n' + self._ip.history_manager.input_hist_parsed[-1]
+            cell_code = "\n" + self._ip.history_manager.input_hist_parsed[-1]
         except Exception:
             return
 
@@ -1415,7 +1498,7 @@ class _JupyterHistoryLogger(object):
             return
         # noinspection PyBroadException
         try:
-            cell_code = '\n' + '\n'.join(self._ip.history_manager.input_hist_parsed[:-1])
+            cell_code = "\n" + "\n".join(self._ip.history_manager.input_hist_parsed[:-1])
         except Exception:
            return
 
@@ -1425,8 +1508,8 @@ class _JupyterHistoryLogger(object):
     def _conform_code(self, cell_code, replace_magic_bang=False):
         # fix magic / bang in code
         if self._replace_ipython_pattern:
-            cell_code = self._replace_ipython_pattern.sub(r'\n# \g<1>get_ipython()', cell_code)
+            cell_code = self._replace_ipython_pattern.sub(r"\n# \g<1>get_ipython()", cell_code)
         if replace_magic_bang and self._replace_magic_pattern and self._replace_bang_pattern:
-            cell_code = self._replace_magic_pattern.sub(r'\n# \g<1>%', cell_code)
-            cell_code = self._replace_bang_pattern.sub(r'\n# \g<1>!', cell_code)
+            cell_code = self._replace_magic_pattern.sub(r"\n# \g<1>%", cell_code)
+            cell_code = self._replace_bang_pattern.sub(r"\n# \g<1>!", cell_code)
         return cell_code
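
Note for reviewers (not part of the patch): since the commit claims to be a pure Black reformat, the change can be sanity-checked by asking Black for the diff it would still apply to the touched file. The sketch below assumes Black is installed and that the project uses a 120-character line length — the long lines left intact in this patch are consistent with that, but the repository's own pyproject.toml (or equivalent) is the authority.

    import subprocess
    import sys


    def black_diff(path, line_length=120):
        """Return the diff Black would apply to `path`, without modifying the file."""
        result = subprocess.run(
            [sys.executable, "-m", "black", "--diff", "--line-length", str(line_length), path],
            capture_output=True,
            text=True,
        )
        return result.stdout


    if __name__ == "__main__":
        # An empty diff means the file is already Black-formatted,
        # i.e. the commit above introduced no behavioral change.
        print(black_diff("clearml/backend_interface/task/repo/scriptinfo.py") or "already formatted")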