Add Task.set_credentials for cloud hosted jupyter support

This commit is contained in:
allegroai 2019-07-13 23:54:47 +03:00
parent cac4ac12b8
commit 7d0bf4838e
6 changed files with 67 additions and 28 deletions

View File

@ -1,14 +1,15 @@
{ {
version: 1.5 version: 1.5
host: https://demoapi.trainsai.io # default https://demoapi.trainsai.io host
host: ""
# verify host ssl certificate, set to False only if you have a very good reason # verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True verify_certificate: True
# default demoapi.trainsai.io credentials # default demoapi.trainsai.io credentials
credentials { credentials {
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" access_key: ""
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" secret_key: ""
} }
# default version assigned to requests with no specific version. this is not expected to change # default version assigned to requests with no specific version. this is not expected to change

View File

@ -3,9 +3,9 @@ import sys
import types import types
from socket import gethostname from socket import gethostname
import jwt
import requests import requests
import six import six
import jwt
from pyhocon import ConfigTree from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
@ -36,6 +36,9 @@ class Session(TokenManager):
_session_timeout = (5.0, None) _session_timeout = (5.0, None)
api_version = '2.1' api_version = '2.1'
default_host = "https://demoapi.trainsai.io"
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
# TODO: add requests.codes.gateway_timeout once we support async commits # TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [ _retry_codes = [
@ -94,7 +97,7 @@ class Session(TokenManager):
self._logger = logger self._logger = logger
self.__access_key = api_key or ENV_ACCESS_KEY.get( self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=self.config.get("api.credentials.access_key", None) default=(self.config.get("api.credentials.access_key") or self.default_key)
) )
if not self.access_key: if not self.access_key:
raise ValueError( raise ValueError(
@ -102,14 +105,14 @@ class Session(TokenManager):
) )
self.__secret_key = secret_key or ENV_SECRET_KEY.get( self.__secret_key = secret_key or ENV_SECRET_KEY.get(
default=self.config.get("api.credentials.secret_key", None) default=(self.config.get("api.credentials.secret_key") or self.default_secret)
) )
if not self.secret_key: if not self.secret_key:
raise ValueError( raise ValueError(
"Missing secret_key. Please set in configuration file or pass in session init." "Missing secret_key. Please set in configuration file or pass in session init."
) )
host = host or ENV_HOST.get(default=self.config.get("api.host")) host = host or self.get_api_server_host(config=self.config)
if not host: if not host:
raise ValueError("host is required in init or config") raise ValueError("host is required in init or config")
@ -386,6 +389,13 @@ class Session(TokenManager):
return call_result return call_result
@classmethod
def get_api_server_host(cls, config=None):
if not config:
from ...config import config_obj
config = config_obj
return ENV_HOST.get(default=(config.get("api.host") or cls.default_host))
def _do_refresh_token(self, old_token, exp=None): def _do_refresh_token(self, old_token, exp=None):
""" TokenManager abstract method implementation. """ TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token. Here we ignore the old token and simply obtain a new token.

View File

@ -4,9 +4,10 @@ import requests.exceptions
import six import six
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.session import BatchRequest from ..backend_api.session import BatchRequest
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
from ..config import config_obj from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY from ..config.defs import LOG_LEVEL_ENV_VAR
from ..debugging import get_logger from ..debugging import get_logger
from ..backend_api.version import __version__ from ..backend_api.version import __version__
from .session import SendError, SessionInterface from .session import SendError, SessionInterface
@ -78,8 +79,8 @@ class InterfaceBase(SessionInterface):
initialize_logging=False, initialize_logging=False,
client='sdk-%s' % __version__, client='sdk-%s' % __version__,
config=config_obj, config=config_obj,
api_key=API_ACCESS_KEY.get(), api_key=ENV_ACCESS_KEY.get(),
secret_key=API_SECRET_KEY.get(), secret_key=ENV_SECRET_KEY.get(),
) )
return InterfaceBase._default_session return InterfaceBase._default_session

View File

@ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse
import six import six
from ...backend_api.session.defs import ENV_HOST
from ...backend_interface.task.development.worker import DevWorker from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects from ...backend_api.services import tasks, models, events, projects
@ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \ from ..util import make_message, get_or_create_project, get_single_result, \
exact_match_regex exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \ from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
running_remotely, get_cache_dir, config_obj running_remotely, get_cache_dir
from ...debugging import get_logger from ...debugging import get_logger
from ...debugging.log import LoggerRoot from ...debugging.log import LoggerRoot
from ...storage import StorageHelper from ...storage import StorageHelper
@ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# overwrite it before we have a chance to call edit) # overwrite it before we have a chance to call edit)
self._edit(script=result.script) self._edit(script=result.script)
self.reload() self.reload()
if result.script.get('requirements'): self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
self._update_requirements(result.script.get('requirements'))
check_package_update_thread.join() check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training): def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
@ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
app_host = self._get_app_server() app_host = self._get_app_server()
parsed = urlparse(app_host) parsed = urlparse(app_host)
if parsed.port: if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081')) parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1))
elif parsed.netloc.startswith('demoapp.'): elif parsed.netloc.startswith('demoapp.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.')) parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1))
elif parsed.netloc.startswith('app.'):
parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1))
else: else:
parsed = parsed._replace(netloc=parsed.netloc+':8081') parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed) return urlunparse(parsed)
@classmethod @classmethod
def _get_api_server(cls): def _get_api_server(cls):
return ENV_HOST.get(default=config_obj.get("api.host")) return Session.get_api_server_host()
@classmethod @classmethod
def _get_app_server(cls): def _get_app_server(cls):
host = cls._get_api_server() host = cls._get_api_server()
if '://demoapi.' in host: if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.') return host.replace('://demoapi.', '://demoapp.', 1)
if '://api.' in host: if '://api.' in host:
return host.replace('://api.', '://app.') return host.replace('://api.', '://app.', 1)
parsed = urlparse(host) parsed = urlparse(host)
if parsed.port == 8008: if parsed.port == 8008:
return host.replace(':8008', ':8080') return host.replace(':8008', ':8080', 1)
def _edit(self, **kwargs): def _edit(self, **kwargs):
with self._edit_lock: with self._edit_lock:
@ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _update_requirements(self, requirements): def _update_requirements(self, requirements):
if not isinstance(requirements, dict): if not isinstance(requirements, dict):
requirements = {'pip': requirements} requirements = {'pip': requirements}
# protection, Old API might not support it
try:
self.data.script.requirements = requirements self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
except Exception:
pass
def _update_script(self, script): def _update_script(self, script):
self.data.script = script self.data.script = script

View File

@ -57,29 +57,29 @@ def main():
if parsed_host.port == 8080: if parsed_host.port == 8080:
# this is a docker 8080 is the web address, we need the api address, it is 8008 # this is a docker 8080 is the web address, we need the api address, it is 8008
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008') print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'): elif parsed_host.netloc.startswith('demoapp.'):
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format( print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
parsed_host.netloc)) parsed_host.netloc))
# this is our demo server # this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('app.'): elif parsed_host.netloc.startswith('app.'):
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format( print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
parsed_host.netloc)) parsed_host.netloc))
# this is our application server # this is our application server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.port == 8008: elif parsed_host.port == 8008:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path
elif parsed_host.netloc.startswith('demoapi.'): elif parsed_host.netloc.startswith('demoapi.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path
elif parsed_host.netloc.startswith('api.'): elif parsed_host.netloc.startswith('api.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path
else: else:
api_host = None api_host = None
web_host = None web_host = None

View File

@ -11,6 +11,7 @@ import psutil
import six import six
from .backend_api.services import tasks, projects from .backend_api.services import tasks, projects
from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments from .backend_interface.task.args import _Arguments
@ -25,6 +26,7 @@ from .errors import UsageError
from .logger import Logger from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters from .task_parameters import TaskParameters
from .binding.environ_bind import EnvironmentBind
from .binding.absl_bind import PatchAbsl from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask argparser_update_currenttask
@ -212,6 +214,7 @@ class Task(_Task):
Task.__main_task = task Task.__main_task = task
# Patch argparse to be aware of the current task # Patch argparse to be aware of the current task
argparser_update_currenttask(Task.__main_task) argparser_update_currenttask(Task.__main_task)
EnvironmentBind.update_current_task(Task.__main_task)
if auto_connect_frameworks: if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task) PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task) PatchAbsl.update_current_task(Task.__main_task)
@ -687,6 +690,26 @@ class Task(_Task):
self.data.last_iteration = int(last_iteration) self.data.last_iteration = int(last_iteration)
self._edit(last_iteration=self.data.last_iteration) self._edit(last_iteration=self.data.last_iteration)
@classmethod
def set_credentials(cls, host=None, key=None, secret=None):
"""
Set new default TRAINS-server host and credentials
These configurations will be overridden by wither OS environment variables or trains.conf configuration file
Notice: credentials needs to be set prior to Task initialization
:param host: host url, example: host='http://localhost:8008'
:type host: str
:param key: user key/secret pair, example: key='thisisakey123'
:type key: str
:param secret: user key/secret pair, example: secret='thisisseceret123'
:type secret: str
"""
if host:
Session.default_host = host
if key:
Session.default_key = key
if secret:
Session.default_secret = secret
def _connect_output_model(self, model): def _connect_output_model(self, model):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self) model.connect(self)