Allow for automatic identification of SageMaker notebooks

This commit is contained in:
Alex Burlacu 2023-03-23 18:24:46 +02:00
parent 4bb83c1c6c
commit a5d25b1a88

View File

@ -4,10 +4,12 @@ from copy import copy
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from tempfile import gettempdir, mkdtemp from tempfile import gettempdir, mkdtemp
from urllib.parse import urlparse
import attr import attr
import logging import logging
import json import json
import requests
from pathlib2 import Path from pathlib2 import Path
from threading import Thread from threading import Thread
@ -536,6 +538,8 @@ class _JupyterObserver(object):
class ScriptInfo(object): class ScriptInfo(object):
_sagemaker_metadata_path = "/opt/ml/metadata/resource-metadata.json"
max_diff_size_bytes = 500000 max_diff_size_bytes = 500000
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()] plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
@ -623,7 +627,6 @@ class ScriptInfo(object):
except Exception: except Exception:
pass pass
import requests
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '') current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
notebook_path = None notebook_path = None
@ -673,7 +676,7 @@ class ScriptInfo(object):
r.raise_for_status() r.raise_for_status()
except Exception as ex: except Exception as ex:
# raise on last one only # raise on last one only
if server_index == len(jupyter_servers)-1: if server_index == len(jupyter_servers) - 1:
cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format( cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format(
' [password={}]'.format(password) if server_info.get('password') else '', ex)) ' [password={}]'.format(password) if server_info.get('password') else '', ex))
return os.path.join(os.getcwd(), 'error_notebook_not_found.py') return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
@ -694,6 +697,9 @@ class ScriptInfo(object):
if notebook_path: if notebook_path:
break break
if (not notebook_name or not notebook_path) and ScriptInfo.is_sagemaker():
notebook_path, notebook_name = ScriptInfo._get_sagemaker_notebook(current_kernel)
is_google_colab = False is_google_colab = False
log_history = False log_history = False
colab_name = None colab_name = None
@ -765,6 +771,37 @@ class ScriptInfo(object):
except Exception: except Exception:
return None return None
@classmethod
def is_sagemaker(cls):
    """
    Tell whether the SageMaker resource metadata file exists on this machine.

    :return: True if `cls._sagemaker_metadata_path` points to an existing regular file, False otherwise
    """
    metadata_file = Path(cls._sagemaker_metadata_path)
    return metadata_file.is_file()
@classmethod
def _get_sagemaker_notebook(cls, current_kernel, timeout=30):
    """
    Best-effort lookup of the notebook attached to a Jupyter kernel in a SageMaker environment.

    Reads "DomainId" and "UserProfileName" from the SageMaker resource metadata file, asks boto3 for a
    presigned domain URL, opens that URL in an HTTP session and then queries the Jupyter sessions API
    on the same host for a session whose kernel id equals `current_kernel`.

    :param current_kernel: kernel id to look for in the Jupyter sessions list
    :param timeout: timeout (seconds) passed to each HTTP request
    :return: tuple (notebook path, notebook name) of the matching session; (None, None) when no session
        matches or when any step fails (failures are logged as a warning and never raised)
    """
    # noinspection PyBroadException
    try:
        # we expect to find boto3 in the sagemaker env
        import boto3
        with open(cls._sagemaker_metadata_path) as f:
            notebook_data = json.load(f)
        client = boto3.client("sagemaker")
        response = client.create_presigned_domain_url(
            DomainId=notebook_data["DomainId"],
            UserProfileName=notebook_data["UserProfileName"]
        )
        authorized_url = response["AuthorizedUrl"]
        authorized_url_parsed = urlparse(authorized_url)
        # keep only scheme + host, the API request below is sent to the same server
        unauthorized_url = authorized_url_parsed.scheme + "://" + authorized_url_parsed.netloc
        with requests.Session() as s:
            # presumably this first request lets the session pick up the authentication cookies - TODO confirm
            s.get(authorized_url, timeout=timeout)
            sessions_response = s.get(unauthorized_url + "/jupyter/default/api/sessions", timeout=timeout)
            # fail with the actual HTTP error instead of trying to iterate over an error payload
            sessions_response.raise_for_status()
            jupyter_sessions = sessions_response.json()
        for jupyter_session in jupyter_sessions:
            # "kernel" may be present but null, do not let one such session abort the whole scan
            kernel_info = jupyter_session.get("kernel") or {}
            if kernel_info.get("id") == current_kernel:
                return jupyter_session.get("path", ""), jupyter_session.get("name", "")
    except Exception as e:
        cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
    return None, None
@classmethod @classmethod
def _get_colab_notebook(cls, timeout=30): def _get_colab_notebook(cls, timeout=30):
# returns tuple (notebook name, raw string notebook) # returns tuple (notebook name, raw string notebook)