Allow for automatic identification of SageMaker notebooks

This commit is contained in:
Alex Burlacu 2023-03-23 18:24:46 +02:00
parent 4bb83c1c6c
commit a5d25b1a88

View File

@ -4,10 +4,12 @@ from copy import copy
from datetime import datetime
from functools import partial
from tempfile import gettempdir, mkdtemp
from urllib.parse import urlparse
import attr
import logging
import json
import requests
from pathlib2 import Path
from threading import Thread
@ -536,6 +538,8 @@ class _JupyterObserver(object):
class ScriptInfo(object):
_sagemaker_metadata_path = "/opt/ml/metadata/resource-metadata.json"
max_diff_size_bytes = 500000
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
@ -623,7 +627,6 @@ class ScriptInfo(object):
except Exception:
pass
import requests
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
notebook_path = None
@ -673,7 +676,7 @@ class ScriptInfo(object):
r.raise_for_status()
except Exception as ex:
# raise on last one only
if server_index == len(jupyter_servers)-1:
if server_index == len(jupyter_servers) - 1:
cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format(
' [password={}]'.format(password) if server_info.get('password') else '', ex))
return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
@ -694,6 +697,9 @@ class ScriptInfo(object):
if notebook_path:
break
if (not notebook_name or not notebook_path) and ScriptInfo.is_sagemaker():
notebook_path, notebook_name = ScriptInfo._get_sagemaker_notebook(current_kernel)
is_google_colab = False
log_history = False
colab_name = None
@ -765,6 +771,37 @@ class ScriptInfo(object):
except Exception:
return None
@classmethod
def is_sagemaker(cls):
    """
    Detect a SageMaker environment by checking for its resource metadata file.

    :return: True when the path stored in ``_sagemaker_metadata_path`` is an existing
        regular file, False otherwise
    """
    metadata_file = Path(cls._sagemaker_metadata_path)
    return metadata_file.is_file()
@classmethod
def _get_sagemaker_notebook(cls, current_kernel, timeout=30):
    """
    Find the notebook attached to a kernel by querying the SageMaker Jupyter sessions API.

    The domain id and user profile name are read from the SageMaker metadata file, a presigned
    domain URL is requested through boto3, that URL is opened in a ``requests`` session, and the
    same session is then used to list ``/jupyter/default/api/sessions`` on the same host.

    This is a best-effort lookup: any failure (boto3 missing, metadata file unreadable,
    HTTP error, unexpected payload) is logged as a warning and reported as "not found".

    :param current_kernel: kernel id compared against each session's ``kernel.id``
    :param timeout: timeout, in seconds, passed to each HTTP request
    :return: tuple ``(path, name)`` of the first matching session (each defaults to ``""`` when
        the key is absent), or ``(None, None)`` when nothing matches or the lookup fails
    """
    # noinspection PyBroadException
    try:
        # we expect to find boto3 in the sagemaker env
        import boto3

        with open(cls._sagemaker_metadata_path) as f:
            notebook_data = json.load(f)

        client = boto3.client("sagemaker")
        response = client.create_presigned_domain_url(
            DomainId=notebook_data["DomainId"],
            UserProfileName=notebook_data["UserProfileName"]
        )
        authorized_url = response["AuthorizedUrl"]
        authorized_url_parsed = urlparse(authorized_url)
        # same scheme and host as the presigned URL, without its path and query
        unauthorized_url = authorized_url_parsed.scheme + "://" + authorized_url_parsed.netloc

        with requests.Session() as s:
            # NOTE(review): presumably opening the presigned URL first is what authenticates
            # this session for the follow-up API call - confirm
            s.get(authorized_url, timeout=timeout)
            sessions_response = s.get(unauthorized_url + "/jupyter/default/api/sessions", timeout=timeout)
            # report HTTP failures as such instead of trying to parse an error payload
            sessions_response.raise_for_status()
            jupyter_sessions = sessions_response.json()

        for jupyter_session in jupyter_sessions:
            # "kernel" may be present but null; do not let one such entry abort the whole search
            kernel_info = jupyter_session.get("kernel") or {}
            if kernel_info.get("id") == current_kernel:
                return jupyter_session.get("path", ""), jupyter_session.get("name", "")
    except Exception as e:
        cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
    return None, None
@classmethod
def _get_colab_notebook(cls, timeout=30):
# returns tuple (notebook name, raw string notebook)