From 675dc32528bbd44ae50acccd8a88c4f45d7c9ef6 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 19 Jun 2019 15:14:28 +0300 Subject: [PATCH] Add api.verify_certificate to allow non-verified ssl connection (enterprise firewall mitm scenarios), by default only secured connections are allowed. Add latest package check, as long as we are in pre-release. --- docs/faq.md | 2 +- docs/trains.conf | 4 ++ requirements.txt | 6 +-- trains/backend_api/config/default/api.conf | 3 ++ trains/backend_api/session/defs.py | 1 + trains/backend_api/utils.py | 29 +++++++++++-- trains/backend_interface/task/task.py | 48 ++++++++++++++++++---- trains/task.py | 22 +++++++--- trains/utilities/async_manager.py | 2 +- 9 files changed, 94 insertions(+), 23 deletions(-) diff --git a/docs/faq.md b/docs/faq.md index 9ff16652..50effc1e 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -118,7 +118,7 @@ taks = Task.init(project_name, task_name, output_uri="gs://bucket-name/folder") ``` **NOTE:** These require configuring the storage credentials in `~/trains.conf`. -For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L51). +For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L55). ## I am training multiple models at the same time, but I only see one of them. What happened? diff --git a/docs/trains.conf b/docs/trains.conf index 4387b798..817087a5 100644 --- a/docs/trains.conf +++ b/docs/trains.conf @@ -2,8 +2,12 @@ api { # Notice: 'host' is the api server (default port 8008), not the web server. 
host: http://localhost:8008 + # Credentials are generated in the webapp, http://localhost:8080/admin credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"} + + # verify host ssl certificate, set to False only if you have a very good reason + verify_certificate: True } sdk { # TRAINS - default SDK configuration diff --git a/requirements.txt b/requirements.txt index 0de50a93..89208f01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,16 +17,16 @@ jsonschema>=2.6.0 numpy>=1.10 opencv-python>=3.2.0.8 pathlib2>=2.3.0 +plotly>=3.9.0 psutil>=3.4.2 pyhocon>=0.3.38 python-dateutil>=2.6.1 +pyjwt>=1.6.4 PyYAML>=3.12 requests-file>=1.4.2 requests>=2.18.4 six>=1.11.0 tqdm>=4.19.5 +typing>=3.6.4 urllib3>=1.22 watchdog>=0.8.0 -pyjwt>=1.6.4 -plotly>=3.9.0 -typing>=3.6.4 diff --git a/trains/backend_api/config/default/api.conf b/trains/backend_api/config/default/api.conf index 1ba191c9..81c3c39f 100644 --- a/trains/backend_api/config/default/api.conf +++ b/trains/backend_api/config/default/api.conf @@ -2,6 +2,9 @@ version: 1.5 host: https://demoapi.trainsai.io + # verify host ssl certificate, set to False only if you have a very good reason + verify_certificate: True + # default demoapi.trainsai.io credentials credentials { access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" diff --git a/trains/backend_api/session/defs.py b/trains/backend_api/session/defs.py index 2b0f1be3..73fc6603 100644 --- a/trains/backend_api/session/defs.py +++ b/trains/backend_api/session/defs.py @@ -5,3 +5,4 @@ ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST") ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY") ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY") ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False) +ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True) diff --git 
a/trains/backend_api/utils.py b/trains/backend_api/utils.py index e2756798..da25e129 100644 --- a/trains/backend_api/utils.py +++ b/trains/backend_api/utils.py @@ -1,3 +1,4 @@ +import logging import ssl import sys @@ -8,6 +9,8 @@ from urllib3.util import Retry from urllib3 import PoolManager import six +from .session.defs import ENV_HOST_VERIFY_CERT + if six.PY3: from functools import lru_cache elif six.PY2: @@ -15,12 +18,13 @@ elif six.PY2: from backports.functools_lru_cache import lru_cache +__disable_certificate_verification_warning = 0 + + @lru_cache() def get_config(): - from ..backend_config import Config - config = Config(verbose=False) - config.reload() - return config + from ..config import config_obj + return config_obj class TLSv1HTTPAdapter(HTTPAdapter): @@ -42,6 +46,7 @@ def get_http_session_with_retry( backoff_max=None, pool_connections=None, pool_maxsize=None): + global __disable_certificate_verification_warning if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)): raise ValueError('Bad configuration. All retry count values must be null or int') @@ -72,6 +77,22 @@ def get_http_session_with_retry( adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize) session.mount('http://', adapter) session.mount('https://', adapter) + # update verify host certificate + session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True)) + if not session.verify and __disable_certificate_verification_warning < 2: + # show warning + __disable_certificate_verification_warning += 1 + logging.getLogger('TRAINS').warning( + msg='InsecureRequestWarning: Certificate verification is disabled! Adding ' + 'certificate verification is strongly advised. 
See: ' + 'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings') + # make sure we only do not see the warning + import urllib3 + # noinspection PyBroadException + try: + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + except Exception: + pass return session diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index c36d8f8c..152fe992 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -2,11 +2,13 @@ import collections import itertools import logging +from threading import RLock, Thread from copy import copy from six.moves.urllib.parse import urlparse, urlunparse import six +from ...backend_api.session.defs import ENV_HOST from ...backend_interface.task.development.worker import DevWorker from ...backend_api import Session from ...backend_api.services import tasks, models, events, projects @@ -83,6 +85,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._parameters_allowed_types = ( six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None)) ) + self._edit_lock = RLock() if not task_id: # generate a new task @@ -172,6 +175,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): return task_id def _update_repository(self): + def check_package_update(): + # check latest version + from ...utilities.check_updates import CheckPackageUpdates + latest_version = CheckPackageUpdates.check_new_package_available(only_once=True) + if latest_version: + if not latest_version[1]: + self.get_logger().console( + 'TRAINS new package available: UPGRADE to v{} is recommended!'.format( + latest_version[0]), + ) + else: + self.get_logger().console( + 'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format( + latest_version[0]), + ) + + check_package_update_thread = Thread(target=check_package_update) + check_package_update_thread.start() result = ScriptInfo.get(log=self.log) for msg in 
result.warning_messages: self.get_logger().console(msg) @@ -180,6 +201,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # Since we might run asynchronously, don't use self.data (lest someone else # overwrite it before we have a chance to call edit) self._edit(script=result.script) + self.reload() + check_package_update_thread.join() def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training): created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s') @@ -335,8 +358,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _reload(self): """ Reload the task object from the backend """ - res = self.send(tasks.GetByIdRequest(task=self.id)) - return res.response.task + with self._edit_lock: + res = self.send(tasks.GetByIdRequest(task=self.id)) + return res.response.task def reset(self, set_started_on_success=True): """ Reset the task. Task will be reloaded following a successful reset. """ @@ -645,8 +669,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): parsed = parsed._replace(netloc=parsed.netloc+':8081') return urlunparse(parsed) - def _get_app_server(self): - host = config_obj.get('api.host') + @classmethod + def _get_api_server(cls): + return ENV_HOST.get(default=config_obj.get("api.host")) + + @classmethod + def _get_app_server(cls): + host = cls._get_api_server() if '://demoapi.' in host: return host.replace('://demoapi.', '://demoapp.') if '://api.' 
in host: @@ -657,12 +686,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): return host.replace(':8008', ':8080') def _edit(self, **kwargs): - # Since we ae using forced update, make sure he task status is valid - if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)): - raise ValueError('Task object can only be updated if created or in_progress') + with self._edit_lock: + # Since we are using forced update, make sure the task status is valid + if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)): + raise ValueError('Task object can only be updated if created or in_progress') - res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) - return res + res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) + return res @classmethod def create_new_task(cls, session, task_entry, log=None): diff --git a/trains/task.py b/trains/task.py index 0eb45dcc..74576653 100644 --- a/trains/task.py +++ b/trains/task.py @@ -9,7 +9,6 @@ from collections import OrderedDict, Callable import psutil import six -from six.moves._thread import start_new_thread from .backend_api.services import tasks, projects from .backend_interface import TaskStatusEnum @@ -100,6 +99,8 @@ class Task(_Task): self._dev_stop_signal = None self._dev_mode_periodic_flag = False self._connected_parameter_type = None + self._detect_repo_async_thread = None + self._lock = threading.RLock() # register atexit, so that we mark the task as stopped self._at_exit_called = False self.__register_at_exit(self._at_exit) @@ -391,11 +392,12 @@ class Task(_Task): # update current repository and put warning into logs if in_dev_mode and cls.__detect_repo_async: - start_new_thread(task._update_repository, tuple()) + task._detect_repo_async_thread = threading.Thread(target=task._update_repository) + task._detect_repo_async_thread.start() else: 
task._update_repository() - # show the debug metrics page in the log, it is very convinient + # show the debug metrics page in the log, it is very convenient logger.console( 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format( task._get_app_server(), @@ -543,6 +545,16 @@ class Task(_Task): """ self._dev_mode_periodic() + # wait for detection repo sync + if self._detect_repo_async_thread: + with self._lock: + if self._detect_repo_async_thread: + try: + self._detect_repo_async_thread.join() + self._detect_repo_async_thread = None + except Exception: + pass + # make sure model upload is done if BackendModel.get_num_results() > 0 and wait_for_uploads: BackendModel.wait_for_results() @@ -1086,7 +1098,7 @@ class Task(_Task): @classmethod def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type): - hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type) + hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) # check if we have a cached task_id we can reuse # it must be from within the last 24h and with the same project/name/type @@ -1111,7 +1123,7 @@ class Task(_Task): @classmethod def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id): - hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type) + hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) task_id = str(task_id) # update task session cache diff --git a/trains/utilities/async_manager.py b/trains/utilities/async_manager.py index cd0efdc1..ab596b31 100644 --- a/trains/utilities/async_manager.py +++ b/trains/utilities/async_manager.py @@ -36,7 +36,7 @@ class AsyncManagerMixin(object): if r.ready(): continue t = time.time() - r.wait(timeout=remaining) + r.wait(timeout=remaining or 2.0) count += 1 if max_num_uploads is not None and 
max_num_uploads - count <= 0: break