mirror of
https://github.com/clearml/clearml
synced 2025-04-05 13:15:17 +00:00
Add Task.set_credentials for cloud hosted jupyter support
This commit is contained in:
parent
cac4ac12b8
commit
7d0bf4838e
@ -1,14 +1,15 @@
|
||||
{
|
||||
version: 1.5
|
||||
host: https://demoapi.trainsai.io
|
||||
# default https://demoapi.trainsai.io host
|
||||
host: ""
|
||||
|
||||
# verify host ssl certificate, set to False only if you have a very good reason
|
||||
verify_certificate: True
|
||||
|
||||
# default demoapi.trainsai.io credentials
|
||||
credentials {
|
||||
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
access_key: ""
|
||||
secret_key: ""
|
||||
}
|
||||
|
||||
# default version assigned to requests with no specific version. this is not expected to change
|
||||
|
@ -3,9 +3,9 @@ import sys
|
||||
import types
|
||||
from socket import gethostname
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
import six
|
||||
import jwt
|
||||
from pyhocon import ConfigTree
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
@ -36,6 +36,9 @@ class Session(TokenManager):
|
||||
_session_timeout = (5.0, None)
|
||||
|
||||
api_version = '2.1'
|
||||
default_host = "https://demoapi.trainsai.io"
|
||||
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
# TODO: add requests.codes.gateway_timeout once we support async commits
|
||||
_retry_codes = [
|
||||
@ -94,7 +97,7 @@ class Session(TokenManager):
|
||||
self._logger = logger
|
||||
|
||||
self.__access_key = api_key or ENV_ACCESS_KEY.get(
|
||||
default=self.config.get("api.credentials.access_key", None)
|
||||
default=(self.config.get("api.credentials.access_key") or self.default_key)
|
||||
)
|
||||
if not self.access_key:
|
||||
raise ValueError(
|
||||
@ -102,14 +105,14 @@ class Session(TokenManager):
|
||||
)
|
||||
|
||||
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
|
||||
default=self.config.get("api.credentials.secret_key", None)
|
||||
default=(self.config.get("api.credentials.secret_key") or self.default_secret)
|
||||
)
|
||||
if not self.secret_key:
|
||||
raise ValueError(
|
||||
"Missing secret_key. Please set in configuration file or pass in session init."
|
||||
)
|
||||
|
||||
host = host or ENV_HOST.get(default=self.config.get("api.host"))
|
||||
host = host or self.get_api_server_host(config=self.config)
|
||||
if not host:
|
||||
raise ValueError("host is required in init or config")
|
||||
|
||||
@ -386,6 +389,13 @@ class Session(TokenManager):
|
||||
|
||||
return call_result
|
||||
|
||||
@classmethod
|
||||
def get_api_server_host(cls, config=None):
|
||||
if not config:
|
||||
from ...config import config_obj
|
||||
config = config_obj
|
||||
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
|
||||
|
||||
def _do_refresh_token(self, old_token, exp=None):
|
||||
""" TokenManager abstract method implementation.
|
||||
Here we ignore the old token and simply obtain a new token.
|
||||
|
@ -4,9 +4,10 @@ import requests.exceptions
|
||||
import six
|
||||
from ..backend_api import Session
|
||||
from ..backend_api.session import BatchRequest
|
||||
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
|
||||
|
||||
from ..config import config_obj
|
||||
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY
|
||||
from ..config.defs import LOG_LEVEL_ENV_VAR
|
||||
from ..debugging import get_logger
|
||||
from ..backend_api.version import __version__
|
||||
from .session import SendError, SessionInterface
|
||||
@ -78,8 +79,8 @@ class InterfaceBase(SessionInterface):
|
||||
initialize_logging=False,
|
||||
client='sdk-%s' % __version__,
|
||||
config=config_obj,
|
||||
api_key=API_ACCESS_KEY.get(),
|
||||
secret_key=API_SECRET_KEY.get(),
|
||||
api_key=ENV_ACCESS_KEY.get(),
|
||||
secret_key=ENV_SECRET_KEY.get(),
|
||||
)
|
||||
return InterfaceBase._default_session
|
||||
|
||||
|
@ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse
|
||||
|
||||
import six
|
||||
|
||||
from ...backend_api.session.defs import ENV_HOST
|
||||
from ...backend_interface.task.development.worker import DevWorker
|
||||
from ...backend_api import Session
|
||||
from ...backend_api.services import tasks, models, events, projects
|
||||
@ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin
|
||||
from ..util import make_message, get_or_create_project, get_single_result, \
|
||||
exact_match_regex
|
||||
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
|
||||
running_remotely, get_cache_dir, config_obj
|
||||
running_remotely, get_cache_dir
|
||||
from ...debugging import get_logger
|
||||
from ...debugging.log import LoggerRoot
|
||||
from ...storage import StorageHelper
|
||||
@ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
# overwrite it before we have a chance to call edit)
|
||||
self._edit(script=result.script)
|
||||
self.reload()
|
||||
if result.script.get('requirements'):
|
||||
self._update_requirements(result.script.get('requirements'))
|
||||
self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
|
||||
check_package_update_thread.join()
|
||||
|
||||
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
|
||||
@ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
app_host = self._get_app_server()
|
||||
parsed = urlparse(app_host)
|
||||
if parsed.port:
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081'))
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
|
||||
elif parsed.netloc.startswith('demoapp.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.'))
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
|
||||
elif parsed.netloc.startswith('app.'):
|
||||
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
|
||||
else:
|
||||
parsed = parsed._replace(netloc=parsed.netloc+':8081')
|
||||
return urlunparse(parsed)
|
||||
|
||||
@classmethod
|
||||
def _get_api_server(cls):
|
||||
return ENV_HOST.get(default=config_obj.get("api.host"))
|
||||
return Session.get_api_server_host()
|
||||
|
||||
@classmethod
|
||||
def _get_app_server(cls):
|
||||
host = cls._get_api_server()
|
||||
if '://demoapi.' in host:
|
||||
return host.replace('://demoapi.', '://demoapp.')
|
||||
return host.replace('://demoapi.', '://demoapp.', 1)
|
||||
if '://api.' in host:
|
||||
return host.replace('://api.', '://app.')
|
||||
return host.replace('://api.', '://app.', 1)
|
||||
|
||||
parsed = urlparse(host)
|
||||
if parsed.port == 8008:
|
||||
return host.replace(':8008', ':8080')
|
||||
return host.replace(':8008', ':8080', 1)
|
||||
|
||||
def _edit(self, **kwargs):
|
||||
with self._edit_lock:
|
||||
@ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
def _update_requirements(self, requirements):
|
||||
if not isinstance(requirements, dict):
|
||||
requirements = {'pip': requirements}
|
||||
# protection, Old API might not support it
|
||||
try:
|
||||
self.data.script.requirements = requirements
|
||||
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_script(self, script):
|
||||
self.data.script = script
|
||||
|
@ -57,29 +57,29 @@ def main():
|
||||
if parsed_host.port == 8080:
|
||||
# this is a docker 8080 is the web address, we need the api address, it is 8008
|
||||
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('demoapp.'):
|
||||
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
|
||||
parsed_host.netloc))
|
||||
# this is our demo server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('app.'):
|
||||
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
|
||||
parsed_host.netloc))
|
||||
# this is our application server
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
elif parsed_host.port == 8008:
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('demoapi.'):
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
|
||||
elif parsed_host.netloc.startswith('api.'):
|
||||
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path
|
||||
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
|
||||
else:
|
||||
api_host = None
|
||||
web_host = None
|
||||
|
@ -11,6 +11,7 @@ import psutil
|
||||
import six
|
||||
|
||||
from .backend_api.services import tasks, projects
|
||||
from .backend_api.session.session import Session
|
||||
from .backend_interface.model import Model as BackendModel
|
||||
from .backend_interface.task import Task as _Task
|
||||
from .backend_interface.task.args import _Arguments
|
||||
@ -25,6 +26,7 @@ from .errors import UsageError
|
||||
from .logger import Logger
|
||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||
from .task_parameters import TaskParameters
|
||||
from .binding.environ_bind import EnvironmentBind
|
||||
from .binding.absl_bind import PatchAbsl
|
||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
argparser_update_currenttask
|
||||
@ -212,6 +214,7 @@ class Task(_Task):
|
||||
Task.__main_task = task
|
||||
# Patch argparse to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
EnvironmentBind.update_current_task(Task.__main_task)
|
||||
if auto_connect_frameworks:
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
@ -687,6 +690,26 @@ class Task(_Task):
|
||||
self.data.last_iteration = int(last_iteration)
|
||||
self._edit(last_iteration=self.data.last_iteration)
|
||||
|
||||
@classmethod
|
||||
def set_credentials(cls, host=None, key=None, secret=None):
|
||||
"""
|
||||
Set new default TRAINS-server host and credentials
|
||||
These configurations will be overridden by wither OS environment variables or trains.conf configuration file
|
||||
Notice: credentials needs to be set prior to Task initialization
|
||||
:param host: host url, example: host='http://localhost:8008'
|
||||
:type host: str
|
||||
:param key: user key/secret pair, example: key='thisisakey123'
|
||||
:type key: str
|
||||
:param secret: user key/secret pair, example: secret='thisisseceret123'
|
||||
:type secret: str
|
||||
"""
|
||||
if host:
|
||||
Session.default_host = host
|
||||
if key:
|
||||
Session.default_key = key
|
||||
if secret:
|
||||
Session.default_secret = secret
|
||||
|
||||
def _connect_output_model(self, model):
|
||||
assert isinstance(model, OutputModel)
|
||||
model.connect(self)
|
||||
|
Loading…
Reference in New Issue
Block a user