mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Allow for automatic identification of SageMaker notebooks
This commit is contained in:
parent
4bb83c1c6c
commit
a5d25b1a88
@ -4,10 +4,12 @@ from copy import copy
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tempfile import gettempdir, mkdtemp
|
from tempfile import gettempdir, mkdtemp
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
import requests
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@ -536,6 +538,8 @@ class _JupyterObserver(object):
|
|||||||
|
|
||||||
|
|
||||||
class ScriptInfo(object):
|
class ScriptInfo(object):
|
||||||
|
_sagemaker_metadata_path = "/opt/ml/metadata/resource-metadata.json"
|
||||||
|
|
||||||
max_diff_size_bytes = 500000
|
max_diff_size_bytes = 500000
|
||||||
|
|
||||||
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
|
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
|
||||||
@ -623,7 +627,6 @@ class ScriptInfo(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import requests
|
|
||||||
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
|
current_kernel = sys.argv[2].split(os.path.sep)[-1].replace('kernel-', '').replace('.json', '')
|
||||||
|
|
||||||
notebook_path = None
|
notebook_path = None
|
||||||
@ -673,7 +676,7 @@ class ScriptInfo(object):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
# raise on last one only
|
# raise on last one only
|
||||||
if server_index == len(jupyter_servers)-1:
|
if server_index == len(jupyter_servers) - 1:
|
||||||
cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format(
|
cls._get_logger().warning('Failed accessing the jupyter server{}: {}'.format(
|
||||||
' [password={}]'.format(password) if server_info.get('password') else '', ex))
|
' [password={}]'.format(password) if server_info.get('password') else '', ex))
|
||||||
return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
|
return os.path.join(os.getcwd(), 'error_notebook_not_found.py')
|
||||||
@ -694,6 +697,9 @@ class ScriptInfo(object):
|
|||||||
if notebook_path:
|
if notebook_path:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if (not notebook_name or not notebook_path) and ScriptInfo.is_sagemaker():
|
||||||
|
notebook_path, notebook_name = ScriptInfo._get_sagemaker_notebook(current_kernel)
|
||||||
|
|
||||||
is_google_colab = False
|
is_google_colab = False
|
||||||
log_history = False
|
log_history = False
|
||||||
colab_name = None
|
colab_name = None
|
||||||
@ -765,6 +771,37 @@ class ScriptInfo(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_sagemaker(cls):
|
||||||
|
return Path(cls._sagemaker_metadata_path).is_file()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_sagemaker_notebook(cls, current_kernel, timeout=30):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# we expect to find boto3 in the sagemaker env
|
||||||
|
import boto3
|
||||||
|
with open(cls._sagemaker_metadata_path) as f:
|
||||||
|
notebook_data = json.load(f)
|
||||||
|
client = boto3.client("sagemaker")
|
||||||
|
response = client.create_presigned_domain_url(
|
||||||
|
DomainId=notebook_data["DomainId"],
|
||||||
|
UserProfileName=notebook_data["UserProfileName"]
|
||||||
|
)
|
||||||
|
authorized_url = response["AuthorizedUrl"]
|
||||||
|
authorized_url_parsed = urlparse(authorized_url)
|
||||||
|
unauthorized_url = authorized_url_parsed.scheme + "://" + authorized_url_parsed.netloc
|
||||||
|
with requests.Session() as s:
|
||||||
|
s.get(authorized_url, timeout=timeout)
|
||||||
|
jupyter_sessions = s.get(unauthorized_url + "/jupyter/default/api/sessions", timeout=timeout).json()
|
||||||
|
for jupyter_session in jupyter_sessions:
|
||||||
|
if jupyter_session.get("kernel", {}).get("id") == current_kernel:
|
||||||
|
return jupyter_session.get("path", ""), jupyter_session.get("name", "")
|
||||||
|
except Exception as e:
|
||||||
|
cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_colab_notebook(cls, timeout=30):
|
def _get_colab_notebook(cls, timeout=30):
|
||||||
# returns tuple (notebook name, raw string notebook)
|
# returns tuple (notebook name, raw string notebook)
|
||||||
|
Loading…
Reference in New Issue
Block a user