Add Task.set_credentials for cloud hosted jupyter support

This commit is contained in:
allegroai 2019-07-13 23:54:47 +03:00
parent cac4ac12b8
commit 7d0bf4838e
6 changed files with 67 additions and 28 deletions

View File

@ -1,14 +1,15 @@
{
version: 1.5
host: https://demoapi.trainsai.io
# default https://demoapi.trainsai.io host
host: ""
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
# default demoapi.trainsai.io credentials
credentials {
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
access_key: ""
secret_key: ""
}
# default version assigned to requests with no specific version. this is not expected to change

View File

@ -3,9 +3,9 @@ import sys
import types
from socket import gethostname
import jwt
import requests
import six
import jwt
from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth
@ -36,6 +36,9 @@ class Session(TokenManager):
_session_timeout = (5.0, None)
api_version = '2.1'
default_host = "https://demoapi.trainsai.io"
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
# TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [
@ -94,7 +97,7 @@ class Session(TokenManager):
self._logger = logger
self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=self.config.get("api.credentials.access_key", None)
default=(self.config.get("api.credentials.access_key") or self.default_key)
)
if not self.access_key:
raise ValueError(
@ -102,14 +105,14 @@ class Session(TokenManager):
)
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
default=self.config.get("api.credentials.secret_key", None)
default=(self.config.get("api.credentials.secret_key") or self.default_secret)
)
if not self.secret_key:
raise ValueError(
"Missing secret_key. Please set in configuration file or pass in session init."
)
host = host or ENV_HOST.get(default=self.config.get("api.host"))
host = host or self.get_api_server_host(config=self.config)
if not host:
raise ValueError("host is required in init or config")
@ -386,6 +389,13 @@ class Session(TokenManager):
return call_result
@classmethod
def get_api_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
def _do_refresh_token(self, old_token, exp=None):
""" TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token.

View File

@ -4,9 +4,10 @@ import requests.exceptions
import six
from ..backend_api import Session
from ..backend_api.session import BatchRequest
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY
from ..config.defs import LOG_LEVEL_ENV_VAR
from ..debugging import get_logger
from ..backend_api.version import __version__
from .session import SendError, SessionInterface
@ -78,8 +79,8 @@ class InterfaceBase(SessionInterface):
initialize_logging=False,
client='sdk-%s' % __version__,
config=config_obj,
api_key=API_ACCESS_KEY.get(),
secret_key=API_SECRET_KEY.get(),
api_key=ENV_ACCESS_KEY.get(),
secret_key=ENV_SECRET_KEY.get(),
)
return InterfaceBase._default_session

View File

@ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse
import six
from ...backend_api.session.defs import ENV_HOST
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
@ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \
exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
running_remotely, get_cache_dir, config_obj
running_remotely, get_cache_dir
from ...debugging import get_logger
from ...debugging.log import LoggerRoot
from ...storage import StorageHelper
@ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
if result.script.get('requirements'):
self._update_requirements(result.script.get('requirements'))
self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
@ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
app_host = self._get_app_server()
parsed = urlparse(app_host)
if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081'))
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
elif parsed.netloc.startswith('demoapp.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.'))
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
elif parsed.netloc.startswith('app.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
else:
parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed)
@classmethod
def _get_api_server(cls):
return ENV_HOST.get(default=config_obj.get("api.host"))
return Session.get_api_server_host()
@classmethod
def _get_app_server(cls):
host = cls._get_api_server()
if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.')
return host.replace('://demoapi.', '://demoapp.', 1)
if '://api.' in host:
return host.replace('://api.', '://app.')
return host.replace('://api.', '://app.', 1)
parsed = urlparse(host)
if parsed.port == 8008:
return host.replace(':8008', ':8080')
return host.replace(':8008', ':8080', 1)
def _edit(self, **kwargs):
with self._edit_lock:
@ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _update_requirements(self, requirements):
if not isinstance(requirements, dict):
requirements = {'pip': requirements}
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
# protection, Old API might not support it
try:
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
except Exception:
pass
def _update_script(self, script):
self.data.script = script

View File

@ -57,29 +57,29 @@ def main():
if parsed_host.port == 8080:
# this is a docker 8080 is the web address, we need the api address, it is 8008
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'):
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
parsed_host.netloc))
# this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('app.'):
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
parsed_host.netloc))
# this is our application server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.port == 8008:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapi.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('api.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
else:
api_host = None
web_host = None

View File

@ -11,6 +11,7 @@ import psutil
import six
from .backend_api.services import tasks, projects
from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments
@ -25,6 +26,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.environ_bind import EnvironmentBind
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
@ -212,6 +214,7 @@ class Task(_Task):
Task.__main_task = task
# Patch argparse to be aware of the current task
argparser_update_currenttask(Task.__main_task)
EnvironmentBind.update_current_task(Task.__main_task)
if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task)
@ -687,6 +690,26 @@ class Task(_Task):
self.data.last_iteration = int(last_iteration)
self._edit(last_iteration=self.data.last_iteration)
@classmethod
def set_credentials(cls, host=None, key=None, secret=None):
"""
Set new default TRAINS-server host and credentials
These configurations will be overridden by wither OS environment variables or trains.conf configuration file
Notice: credentials needs to be set prior to Task initialization
:param host: host url, example: host='http://localhost:8008'
:type host: str
:param key: user key/secret pair, example: key='thisisakey123'
:type key: str
:param secret: user key/secret pair, example: secret='thisisseceret123'
:type secret: str
"""
if host:
Session.default_host = host
if key:
Session.default_key = key
if secret:
Session.default_secret = secret
def _connect_output_model(self, model):
assert isinstance(model, OutputModel)
model.connect(self)