From 7d0bf4838e9129b07319c6ea94d379d31b577210 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 13 Jul 2019 23:54:47 +0300 Subject: [PATCH] Add Task.set_credentials for cloud hosted jupyter support --- trains/backend_api/config/default/api.conf | 7 +++--- trains/backend_api/session/session.py | 18 ++++++++++---- trains/backend_interface/base.py | 7 +++--- trains/backend_interface/task/task.py | 28 ++++++++++++---------- trains/config/default/__main__.py | 12 +++++----- trains/task.py | 23 ++++++++++++++++++ 6 files changed, 67 insertions(+), 28 deletions(-) diff --git a/trains/backend_api/config/default/api.conf b/trains/backend_api/config/default/api.conf index 81c3c39f..cbac9189 100644 --- a/trains/backend_api/config/default/api.conf +++ b/trains/backend_api/config/default/api.conf @@ -1,14 +1,15 @@ { version: 1.5 - host: https://demoapi.trainsai.io + # default https://demoapi.trainsai.io host + host: "" # verify host ssl certificate, set to False only if you have a very good reason verify_certificate: True # default demoapi.trainsai.io credentials credentials { - access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" - secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" + access_key: "" + secret_key: "" } # default version assigned to requests with no specific version. this is not expected to change diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py index 1497fdbe..da71098e 100644 --- a/trains/backend_api/session/session.py +++ b/trains/backend_api/session/session.py @@ -3,9 +3,9 @@ import sys import types from socket import gethostname +import jwt import requests import six -import jwt from pyhocon import ConfigTree from requests.auth import HTTPBasicAuth @@ -36,6 +36,9 @@ class Session(TokenManager): _session_timeout = (5.0, None) api_version = '2.1' + default_host = "https://demoapi.trainsai.io" + default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" + default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ @@ -94,7 +97,7 @@ class Session(TokenManager): self._logger = logger self.__access_key = api_key or ENV_ACCESS_KEY.get( - default=self.config.get("api.credentials.access_key", None) + default=(self.config.get("api.credentials.access_key") or self.default_key) ) if not self.access_key: raise ValueError( @@ -102,14 +105,14 @@ class Session(TokenManager): ) self.__secret_key = secret_key or ENV_SECRET_KEY.get( - default=self.config.get("api.credentials.secret_key", None) + default=(self.config.get("api.credentials.secret_key") or self.default_secret) ) if not self.secret_key: raise ValueError( "Missing secret_key. Please set in configuration file or pass in session init." ) - host = host or ENV_HOST.get(default=self.config.get("api.host")) + host = host or self.get_api_server_host(config=self.config) if not host: raise ValueError("host is required in init or config") @@ -386,6 +389,13 @@ class Session(TokenManager): return call_result + @classmethod + def get_api_server_host(cls, config=None): + if not config: + from ...config import config_obj + config = config_obj + return ENV_HOST.get(default=(config.get("api.host") or cls.default_host)) + def _do_refresh_token(self, old_token, exp=None): """ TokenManager abstract method implementation. Here we ignore the old token and simply obtain a new token. diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py index 337de106..b2463c43 100644 --- a/trains/backend_interface/base.py +++ b/trains/backend_interface/base.py @@ -4,9 +4,10 @@ import requests.exceptions import six from ..backend_api import Session from ..backend_api.session import BatchRequest +from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY from ..config import config_obj -from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY +from ..config.defs import LOG_LEVEL_ENV_VAR from ..debugging import get_logger from ..backend_api.version import __version__ from .session import SendError, SessionInterface @@ -78,8 +79,8 @@ class InterfaceBase(SessionInterface): initialize_logging=False, client='sdk-%s' % __version__, config=config_obj, - api_key=API_ACCESS_KEY.get(), - secret_key=API_SECRET_KEY.get(), + api_key=ENV_ACCESS_KEY.get(), + secret_key=ENV_SECRET_KEY.get(), ) return InterfaceBase._default_session diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index b63e70e2..f207f1d9 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse import six -from ...backend_api.session.defs import ENV_HOST from ...backend_interface.task.development.worker import DevWorker from ...backend_api import Session from ...backend_api.services import tasks, models, events, projects @@ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin from ..util import make_message, get_or_create_project, get_single_result, \ exact_match_regex from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \ - running_remotely, get_cache_dir, config_obj + running_remotely, get_cache_dir from ...debugging import get_logger from ...debugging.log import LoggerRoot from ...storage import StorageHelper @@ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # overwrite it before we have a chance to call edit) self._edit(script=result.script) self.reload() - if result.script.get('requirements'): - self._update_requirements(result.script.get('requirements')) + self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '') check_package_update_thread.join() def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training): @@ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): app_host = self._get_app_server() parsed = urlparse(app_host) if parsed.port: - parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081')) + parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1)) elif parsed.netloc.startswith('demoapp.'): - parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.')) + parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1)) + elif parsed.netloc.startswith('app.'): + parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1)) else: parsed = parsed._replace(netloc=parsed.netloc+':8081') return urlunparse(parsed) @classmethod def _get_api_server(cls): - return ENV_HOST.get(default=config_obj.get("api.host")) + return Session.get_api_server_host() @classmethod def _get_app_server(cls): host = cls._get_api_server() if '://demoapi.' in host: - return host.replace('://demoapi.', '://demoapp.') + return host.replace('://demoapi.', '://demoapp.', 1) if '://api.' in host: - return host.replace('://api.', '://app.') + return host.replace('://api.', '://app.', 1) parsed = urlparse(host) if parsed.port == 8008: - return host.replace(':8008', ':8080') + return host.replace(':8008', ':8080', 1) def _edit(self, **kwargs): with self._edit_lock: @@ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _update_requirements(self, requirements): if not isinstance(requirements, dict): requirements = {'pip': requirements} - self.data.script.requirements = requirements - self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) + # protection, Old API might not support it + try: + self.data.script.requirements = requirements + self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) + except Exception: + pass def _update_script(self, script): self.data.script = script diff --git a/trains/config/default/__main__.py b/trains/config/default/__main__.py index 1ea94aae..aa450008 100644 --- a/trains/config/default/__main__.py +++ b/trains/config/default/__main__.py @@ -57,29 +57,29 @@ def main(): if parsed_host.port == 8080: # this is a docker 8080 is the web address, we need the api address, it is 8008 print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008') - api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path + api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path elif parsed_host.netloc.startswith('demoapp.'): print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format( parsed_host.netloc)) # this is our demo server - api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path + api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path elif parsed_host.netloc.startswith('app.'): print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format( parsed_host.netloc)) # this is our application server - api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path + api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path elif parsed_host.port == 8008: api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path elif parsed_host.netloc.startswith('demoapi.'): api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path elif parsed_host.netloc.startswith('api.'): api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path - web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path + web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path else: api_host = None web_host = None diff --git a/trains/task.py b/trains/task.py index 0e986d92..1952a9a3 100644 --- a/trains/task.py +++ b/trains/task.py @@ -11,6 +11,7 @@ import psutil import six from .backend_api.services import tasks, projects +from .backend_api.session.session import Session from .backend_interface.model import Model as BackendModel from .backend_interface.task import Task as _Task from .backend_interface.task.args import _Arguments @@ -25,6 +26,7 @@ from .errors import UsageError from .logger import Logger from .model import InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters +from .binding.environ_bind import EnvironmentBind from .binding.absl_bind import PatchAbsl from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask @@ -212,6 +214,7 @@ class Task(_Task): Task.__main_task = task # Patch argparse to be aware of the current task argparser_update_currenttask(Task.__main_task) + EnvironmentBind.update_current_task(Task.__main_task) if auto_connect_frameworks: PatchedMatplotlib.update_current_task(Task.__main_task) PatchAbsl.update_current_task(Task.__main_task) @@ -687,6 +690,26 @@ class Task(_Task): self.data.last_iteration = int(last_iteration) self._edit(last_iteration=self.data.last_iteration) + @classmethod + def set_credentials(cls, host=None, key=None, secret=None): + """ + Set new default TRAINS-server host and credentials + These configurations will be overridden by wither OS environment variables or trains.conf configuration file + Notice: credentials needs to be set prior to Task initialization + :param host: host url, example: host='http://localhost:8008' + :type host: str + :param key: user key/secret pair, example: key='thisisakey123' + :type key: str + :param secret: user key/secret pair, example: secret='thisisseceret123' + :type secret: str + """ + if host: + Session.default_host = host + if key: + Session.default_key = key + if secret: + Session.default_secret = secret + def _connect_output_model(self, model): assert isinstance(model, OutputModel) model.connect(self)