diff --git a/clearml/backend_interface/task/repo/scriptinfo.py b/clearml/backend_interface/task/repo/scriptinfo.py index 737bb728..7a674d66 100644 --- a/clearml/backend_interface/task/repo/scriptinfo.py +++ b/clearml/backend_interface/task/repo/scriptinfo.py @@ -4,10 +4,12 @@ from copy import copy from datetime import datetime from functools import partial from tempfile import gettempdir, mkdtemp +from urllib.parse import urlparse import attr import logging import json +import requests from pathlib2 import Path from threading import Thread @@ -536,6 +538,8 @@ class _JupyterObserver(object): class ScriptInfo(object): + _sagemaker_metadata_path = "/opt/ml/metadata/resource-metadata.json" + max_diff_size_bytes = 500000 plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()] @@ -623,7 +627,6 @@ class ScriptInfo(object): except Exception: pass - import requests current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '') notebook_path = None @@ -673,7 +676,7 @@ class ScriptInfo(object): r.raise_for_status() except Exception as ex: # raise on last one only - if server_index == len(jupyter_servers)-1: + if server_index == len(jupyter_servers) - 1: cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format( ' [password={}]'.format(password) if server_info.get('password') else '', ex)) return os.path.join(os.getcwd(), 'error_notebook_not_found.py') @@ -694,6 +697,9 @@ class ScriptInfo(object): if notebook_path: break + if (not notebook_name or not notebook_path) and ScriptInfo.is_sagemaker(): + notebook_path, notebook_name = ScriptInfo._get_sagemaker_notebook(current_kernel) + is_google_colab = False log_history = False colab_name = None @@ -765,6 +771,37 @@ class ScriptInfo(object): except Exception: return None + @classmethod + def is_sagemaker(cls): + return Path(cls._sagemaker_metadata_path).is_file() + + @classmethod + def _get_sagemaker_notebook(cls, current_kernel, timeout=30): + # noinspection PyBroadException + try: + # we expect to find boto3 in the sagemaker env + import boto3 + with open(cls._sagemaker_metadata_path) as f: + notebook_data = json.load(f) + client = boto3.client("sagemaker") + response = client.create_presigned_domain_url( + DomainId=notebook_data["DomainId"], + UserProfileName=notebook_data["UserProfileName"] + ) + authorized_url = response["AuthorizedUrl"] + authorized_url_parsed = urlparse(authorized_url) + unauthorized_url = authorized_url_parsed.scheme + "://" + authorized_url_parsed.netloc + with requests.Session() as s: + s.get(authorized_url, timeout=timeout) + jupyter_sessions = s.get(unauthorized_url + "/jupyter/default/api/sessions", timeout=timeout).json() + for jupyter_session in jupyter_sessions: + if jupyter_session.get("kernel", {}).get("id") == current_kernel: + return jupyter_session.get("path", ""), jupyter_session.get("name", "") + except Exception as e: + cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e)) + + return None, None + @classmethod def _get_colab_notebook(cls, timeout=30): # returns tuple (notebook name, raw string notebook)