mirror of
https://github.com/clearml/clearml
synced 2025-04-06 13:45:17 +00:00
Add Task.set_credentials for cloud hosted jupyter support
This commit is contained in:
parent
cac4ac12b8
commit
7d0bf4838e
@ -1,14 +1,15 @@
|
|||||||
{
|
{
|
||||||
version: 1.5
|
version: 1.5
|
||||||
host: https://demoapi.trainsai.io
|
# default https://demoapi.trainsai.io host
|
||||||
|
host: ""
|
||||||
|
|
||||||
# verify host ssl certificate, set to False only if you have a very good reason
|
# verify host ssl certificate, set to False only if you have a very good reason
|
||||||
verify_certificate: True
|
verify_certificate: True
|
||||||
|
|
||||||
# default demoapi.trainsai.io credentials
|
# default demoapi.trainsai.io credentials
|
||||||
credentials {
|
credentials {
|
||||||
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
access_key: ""
|
||||||
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
secret_key: ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# default version assigned to requests with no specific version. this is not expected to change
|
# default version assigned to requests with no specific version. this is not expected to change
|
||||||
|
@ -3,9 +3,9 @@ import sys
|
|||||||
import types
|
import types
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
|
|
||||||
|
import jwt
|
||||||
import requests
|
import requests
|
||||||
import six
|
import six
|
||||||
import jwt
|
|
||||||
from pyhocon import ConfigTree
|
from pyhocon import ConfigTree
|
||||||
from requests.auth import HTTPBasicAuth
|
from requests.auth import HTTPBasicAuth
|
||||||
|
|
||||||
@ -36,6 +36,9 @@ class Session(TokenManager):
|
|||||||
_session_timeout = (5.0, None)
|
_session_timeout = (5.0, None)
|
||||||
|
|
||||||
api_version = '2.1'
|
api_version = '2.1'
|
||||||
|
default_host = "https://demoapi.trainsai.io"
|
||||||
|
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||||
|
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||||
|
|
||||||
# TODO: add requests.codes.gateway_timeout once we support async commits
|
# TODO: add requests.codes.gateway_timeout once we support async commits
|
||||||
_retry_codes = [
|
_retry_codes = [
|
||||||
@ -94,7 +97,7 @@ class Session(TokenManager):
|
|||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
|
||||||
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
||||||
default=self.config.get("api.credentials.access_key", None)
|
default=(self.config.get("api.credentials.access_key") or self.default_key)
|
||||||
)
|
)
|
||||||
if not self.access_key:
|
if not self.access_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -102,14 +105,14 @@ class Session(TokenManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
|
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
|
||||||
default=self.config.get("api.credentials.secret_key", None)
|
default=(self.config.get("api.credentials.secret_key") or self.default_secret)
|
||||||
)
|
)
|
||||||
if not self.secret_key:
|
if not self.secret_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing secret_key. Please set in configuration file or pass in session init."
|
"Missing secret_key. Please set in configuration file or pass in session init."
|
||||||
)
|
)
|
||||||
|
|
||||||
host = host or ENV_HOST.get(default=self.config.get("api.host"))
|
host = host or self.get_api_server_host(config=self.config)
|
||||||
if not host:
|
if not host:
|
||||||
raise ValueError("host is required in init or config")
|
raise ValueError("host is required in init or config")
|
||||||
|
|
||||||
@ -386,6 +389,13 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
return call_result
|
return call_result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_api_server_host(cls, config=None):
|
||||||
|
if not config:
|
||||||
|
from ...config import config_obj
|
||||||
|
config = config_obj
|
||||||
|
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
|
||||||
|
|
||||||
def _do_refresh_token(self, old_token, exp=None):
|
def _do_refresh_token(self, old_token, exp=None):
|
||||||
""" TokenManager abstract method implementation.
|
""" TokenManager abstract method implementation.
|
||||||
Here we ignore the old token and simply obtain a new token.
|
Here we ignore the old token and simply obtain a new token.
|
||||||
|
@ -4,9 +4,10 @@ import requests.exceptions
|
|||||||
import six
|
import six
|
||||||
from ..backend_api import Session
|
from ..backend_api import Session
|
||||||
from ..backend_api.session import BatchRequest
|
from ..backend_api.session import BatchRequest
|
||||||
|
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
|
||||||
|
|
||||||
from ..config import config_obj
|
from ..config import config_obj
|
||||||
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY
|
from ..config.defs import LOG_LEVEL_ENV_VAR
|
||||||
from ..debugging import get_logger
|
from ..debugging import get_logger
|
||||||
from ..backend_api.version import __version__
|
from ..backend_api.version import __version__
|
||||||
from .session import SendError, SessionInterface
|
from .session import SendError, SessionInterface
|
||||||
@ -78,8 +79,8 @@ class InterfaceBase(SessionInterface):
|
|||||||
initialize_logging=False,
|
initialize_logging=False,
|
||||||
client='sdk-%s' % __version__,
|
client='sdk-%s' % __version__,
|
||||||
config=config_obj,
|
config=config_obj,
|
||||||
api_key=API_ACCESS_KEY.get(),
|
api_key=ENV_ACCESS_KEY.get(),
|
||||||
secret_key=API_SECRET_KEY.get(),
|
secret_key=ENV_SECRET_KEY.get(),
|
||||||
)
|
)
|
||||||
return InterfaceBase._default_session
|
return InterfaceBase._default_session
|
||||||
|
|
||||||
|
@ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from ...backend_api.session.defs import ENV_HOST
|
|
||||||
from ...backend_interface.task.development.worker import DevWorker
|
from ...backend_interface.task.development.worker import DevWorker
|
||||||
from ...backend_api import Session
|
from ...backend_api import Session
|
||||||
from ...backend_api.services import tasks, models, events, projects
|
from ...backend_api.services import tasks, models, events, projects
|
||||||
@ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin
|
|||||||
from ..util import make_message, get_or_create_project, get_single_result, \
|
from ..util import make_message, get_or_create_project, get_single_result, \
|
||||||
exact_match_regex
|
exact_match_regex
|
||||||
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
|
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
|
||||||
running_remotely, get_cache_dir, config_obj
|
running_remotely, get_cache_dir
|
||||||
from ...debugging import get_logger
|
from ...debugging import get_logger
|
||||||
from ...debugging.log import LoggerRoot
|
from ...debugging.log import LoggerRoot
|
||||||
from ...storage import StorageHelper
|
from ...storage import StorageHelper
|
||||||
@ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
# overwrite it before we have a chance to call edit)
|
# overwrite it before we have a chance to call edit)
|
||||||
self._edit(script=result.script)
|
self._edit(script=result.script)
|
||||||
self.reload()
|
self.reload()
|
||||||
if result.script.get('requirements'):
|
self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
|
||||||
self._update_requirements(result.script.get('requirements'))
|
|
||||||
check_package_update_thread.join()
|
check_package_update_thread.join()
|
||||||
|
|
||||||
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
|
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
|
||||||
@ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
app_host = self._get_app_server()
|
app_host = self._get_app_server()
|
||||||
parsed = urlparse(app_host)
|
parsed = urlparse(app_host)
|
||||||
if parsed.port:
|
if parsed.port:
|
||||||
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081'))
|
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
|
||||||
elif parsed.netloc.startswith('demoapp.'):
|
elif parsed.netloc.startswith('demoapp.'):
|
||||||
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.'))
|
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
|
||||||
|
elif parsed.netloc.startswith('app.'):
|
||||||
|
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
|
||||||
else:
|
else:
|
||||||
parsed = parsed._replace(netloc=parsed.netloc+':8081')
|
parsed = parsed._replace(netloc=parsed.netloc+':8081')
|
||||||
return urlunparse(parsed)
|
return urlunparse(parsed)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_api_server(cls):
|
def _get_api_server(cls):
|
||||||
return ENV_HOST.get(default=config_obj.get("api.host"))
|
return Session.get_api_server_host()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_app_server(cls):
|
def _get_app_server(cls):
|
||||||
host = cls._get_api_server()
|
host = cls._get_api_server()
|
||||||
if '://demoapi.' in host:
|
if '://demoapi.' in host:
|
||||||
return host.replace('://demoapi.', '://demoapp.')
|
return host.replace('://demoapi.', '://demoapp.', 1)
|
||||||
if '://api.' in host:
|
if '://api.' in host:
|
||||||
return host.replace('://api.', '://app.')
|
return host.replace('://api.', '://app.', 1)
|
||||||
|
|
||||||
parsed = urlparse(host)
|
parsed = urlparse(host)
|
||||||
if parsed.port == 8008:
|
if parsed.port == 8008:
|
||||||
return host.replace(':8008', ':8080')
|
return host.replace(':8008', ':8080', 1)
|
||||||
|
|
||||||
def _edit(self, **kwargs):
|
def _edit(self, **kwargs):
|
||||||
with self._edit_lock:
|
with self._edit_lock:
|
||||||
@ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
def _update_requirements(self, requirements):
|
def _update_requirements(self, requirements):
|
||||||
if not isinstance(requirements, dict):
|
if not isinstance(requirements, dict):
|
||||||
requirements = {'pip': requirements}
|
requirements = {'pip': requirements}
|
||||||
self.data.script.requirements = requirements
|
# protection, Old API might not support it
|
||||||
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
|
try:
|
||||||
|
self.data.script.requirements = requirements
|
||||||
|
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _update_script(self, script):
|
def _update_script(self, script):
|
||||||
self.data.script = script
|
self.data.script = script
|
||||||
|
@ -57,29 +57,29 @@ def main():
|
|||||||
if parsed_host.port == 8080:
|
if parsed_host.port == 8080:
|
||||||
# this is a docker 8080 is the web address, we need the api address, it is 8008
|
# this is a docker 8080 is the web address, we need the api address, it is 8008
|
||||||
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
|
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
elif parsed_host.netloc.startswith('demoapp.'):
|
elif parsed_host.netloc.startswith('demoapp.'):
|
||||||
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
|
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
|
||||||
parsed_host.netloc))
|
parsed_host.netloc))
|
||||||
# this is our demo server
|
# this is our demo server
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
elif parsed_host.netloc.startswith('app.'):
|
elif parsed_host.netloc.startswith('app.'):
|
||||||
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
|
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
|
||||||
parsed_host.netloc))
|
parsed_host.netloc))
|
||||||
# this is our application server
|
# this is our application server
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
elif parsed_host.port == 8008:
|
elif parsed_host.port == 8008:
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||||
elif parsed_host.netloc.startswith('demoapi.'):
|
elif parsed_host.netloc.startswith('demoapi.'):
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
|
||||||
elif parsed_host.netloc.startswith('api.'):
|
elif parsed_host.netloc.startswith('api.'):
|
||||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path
|
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
||||||
else:
|
else:
|
||||||
api_host = None
|
api_host = None
|
||||||
web_host = None
|
web_host = None
|
||||||
|
@ -11,6 +11,7 @@ import psutil
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from .backend_api.services import tasks, projects
|
from .backend_api.services import tasks, projects
|
||||||
|
from .backend_api.session.session import Session
|
||||||
from .backend_interface.model import Model as BackendModel
|
from .backend_interface.model import Model as BackendModel
|
||||||
from .backend_interface.task import Task as _Task
|
from .backend_interface.task import Task as _Task
|
||||||
from .backend_interface.task.args import _Arguments
|
from .backend_interface.task.args import _Arguments
|
||||||
@ -25,6 +26,7 @@ from .errors import UsageError
|
|||||||
from .logger import Logger
|
from .logger import Logger
|
||||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||||
from .task_parameters import TaskParameters
|
from .task_parameters import TaskParameters
|
||||||
|
from .binding.environ_bind import EnvironmentBind
|
||||||
from .binding.absl_bind import PatchAbsl
|
from .binding.absl_bind import PatchAbsl
|
||||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||||
argparser_update_currenttask
|
argparser_update_currenttask
|
||||||
@ -212,6 +214,7 @@ class Task(_Task):
|
|||||||
Task.__main_task = task
|
Task.__main_task = task
|
||||||
# Patch argparse to be aware of the current task
|
# Patch argparse to be aware of the current task
|
||||||
argparser_update_currenttask(Task.__main_task)
|
argparser_update_currenttask(Task.__main_task)
|
||||||
|
EnvironmentBind.update_current_task(Task.__main_task)
|
||||||
if auto_connect_frameworks:
|
if auto_connect_frameworks:
|
||||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||||
PatchAbsl.update_current_task(Task.__main_task)
|
PatchAbsl.update_current_task(Task.__main_task)
|
||||||
@ -687,6 +690,26 @@ class Task(_Task):
|
|||||||
self.data.last_iteration = int(last_iteration)
|
self.data.last_iteration = int(last_iteration)
|
||||||
self._edit(last_iteration=self.data.last_iteration)
|
self._edit(last_iteration=self.data.last_iteration)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_credentials(cls, host=None, key=None, secret=None):
|
||||||
|
"""
|
||||||
|
Set new default TRAINS-server host and credentials
|
||||||
|
These configurations will be overridden by wither OS environment variables or trains.conf configuration file
|
||||||
|
Notice: credentials needs to be set prior to Task initialization
|
||||||
|
:param host: host url, example: host='http://localhost:8008'
|
||||||
|
:type host: str
|
||||||
|
:param key: user key/secret pair, example: key='thisisakey123'
|
||||||
|
:type key: str
|
||||||
|
:param secret: user key/secret pair, example: secret='thisisseceret123'
|
||||||
|
:type secret: str
|
||||||
|
"""
|
||||||
|
if host:
|
||||||
|
Session.default_host = host
|
||||||
|
if key:
|
||||||
|
Session.default_key = key
|
||||||
|
if secret:
|
||||||
|
Session.default_secret = secret
|
||||||
|
|
||||||
def _connect_output_model(self, model):
|
def _connect_output_model(self, model):
|
||||||
assert isinstance(model, OutputModel)
|
assert isinstance(model, OutputModel)
|
||||||
model.connect(self)
|
model.connect(self)
|
||||||
|
Loading…
Reference in New Issue
Block a user