Add api.verify_certificate to allow non-verified SSL connections (enterprise firewall MITM scenarios); by default, only verified connections are allowed.

Add a latest-package check, for as long as we are in pre-release.
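
For reference, a minimal sketch of the two ways the new certificate-verification flag can be set (the `False` value and the localhost host below are illustrative only, for TLS-intercepting proxy setups):

```
# ~/trains.conf -- api section (illustrative values)
api {
    host: http://localhost:8008
    # defaults to True; set False only behind a TLS-intercepting (MITM) proxy
    verify_certificate: False
}
# the same flag can also be overridden via the TRAINS_API_HOST_VERIFY_CERT
# (or ALG_API_HOST_VERIFY_CERT) environment variable
```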
allegroai 2019-06-19 15:14:28 +03:00
parent fdc767c1c5
commit 675dc32528
9 changed files with 94 additions and 23 deletions

View File

@ -118,7 +118,7 @@ taks = Task.init(project_name, task_name, output_uri="gs://bucket-name/folder")
```
**NOTE:** These require configuring the storage credentials in `~/trains.conf`.
For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L51).
For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L55).
## I am training multiple models at the same time, but I only see one of them. What happened? <a name="only-last-model-appears"></a>

View File

@ -2,8 +2,12 @@
api {
# Notice: 'host' is the api server (default port 8008), not the web server.
host: http://localhost:8008
# Credentials are generated in the webapp, http://localhost:8080/admin
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
}
sdk {
# TRAINS - default SDK configuration

View File

@ -17,16 +17,16 @@ jsonschema>=2.6.0
numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0
plotly>=3.9.0
psutil>=3.4.2
pyhocon>=0.3.38
python-dateutil>=2.6.1
pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.18.4
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.22
watchdog>=0.8.0
pyjwt>=1.6.4
plotly>=3.9.0
typing>=3.6.4

View File

@ -2,6 +2,9 @@
version: 1.5
host: https://demoapi.trainsai.io
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
# default demoapi.trainsai.io credentials
credentials {
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"

View File

@ -5,3 +5,4 @@ ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)

View File

@ -1,3 +1,4 @@
import logging
import ssl
import sys
@ -8,6 +9,8 @@ from urllib3.util import Retry
from urllib3 import PoolManager
import six
from .session.defs import ENV_HOST_VERIFY_CERT
if six.PY3:
from functools import lru_cache
elif six.PY2:
@ -15,12 +18,13 @@ elif six.PY2:
from backports.functools_lru_cache import lru_cache
__disable_certificate_verification_warning = 0
@lru_cache()
def get_config():
from ..backend_config import Config
config = Config(verbose=False)
config.reload()
return config
from ..config import config_obj
return config_obj
class TLSv1HTTPAdapter(HTTPAdapter):
@ -42,6 +46,7 @@ def get_http_session_with_retry(
backoff_max=None,
pool_connections=None,
pool_maxsize=None):
global __disable_certificate_verification_warning
if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
raise ValueError('Bad configuration. All retry count values must be null or int')
@ -72,6 +77,22 @@ def get_http_session_with_retry(
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
session.mount('http://', adapter)
session.mount('https://', adapter)
# update verify host certificate
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
if not session.verify and __disable_certificate_verification_warning < 2:
# show warning
__disable_certificate_verification_warning += 1
logging.getLogger('TRAINS').warning(
msg='InsecureRequestWarning: Certificate verification is disabled! Adding '
'certificate verification is strongly advised. See: '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings')
# suppress urllib3's own InsecureRequestWarning so only the warning above is shown
import urllib3
# noinspection PyBroadException
try:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
except Exception:
pass
return session

View File

@ -2,11 +2,13 @@
import collections
import itertools
import logging
from threading import RLock, Thread
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse
import six
from ...backend_api.session.defs import ENV_HOST
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
@ -83,6 +85,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
)
self._edit_lock = RLock()
if not task_id:
# generate a new task
@ -172,6 +175,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return task_id
def _update_repository(self):
def check_package_update():
# check latest version
from ...utilities.check_updates import CheckPackageUpdates
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
self.get_logger().console(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
)
else:
self.get_logger().console(
'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
latest_version[0]),
)
check_package_update_thread = Thread(target=check_package_update)
check_package_update_thread.start()
result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages:
self.get_logger().console(msg)
@ -180,6 +201,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
@ -335,8 +358,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _reload(self):
""" Reload the task object from the backend """
res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task
with self._edit_lock:
res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task
def reset(self, set_started_on_success=True):
""" Reset the task. Task will be reloaded following a successful reset. """
@ -645,8 +669,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed)
def _get_app_server(self):
host = config_obj.get('api.host')
@classmethod
def _get_api_server(cls):
return ENV_HOST.get(default=config_obj.get("api.host"))
@classmethod
def _get_app_server(cls):
host = cls._get_api_server()
if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.')
if '://api.' in host:
@ -657,12 +686,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return host.replace(':8008', ':8080')
def _edit(self, **kwargs):
# Since we are using forced update, make sure the task status is valid
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress')
with self._edit_lock:
# Since we are using forced update, make sure the task status is valid
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res
@classmethod
def create_new_task(cls, session, task_entry, log=None):

View File

@ -9,7 +9,6 @@ from collections import OrderedDict, Callable
import psutil
import six
from six.moves._thread import start_new_thread
from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum
@ -100,6 +99,8 @@ class Task(_Task):
self._dev_stop_signal = None
self._dev_mode_periodic_flag = False
self._connected_parameter_type = None
self._detect_repo_async_thread = None
self._lock = threading.RLock()
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
self.__register_at_exit(self._at_exit)
@ -391,11 +392,12 @@ class Task(_Task):
# update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async:
start_new_thread(task._update_repository, tuple())
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.start()
else:
task._update_repository()
# show the debug metrics page in the log, it is very convinient
# show the debug metrics page in the log, it is very convenient
logger.console(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
@ -543,6 +545,16 @@ class Task(_Task):
"""
self._dev_mode_periodic()
# wait for detection repo sync
if self._detect_repo_async_thread:
with self._lock:
if self._detect_repo_async_thread:
try:
self._detect_repo_async_thread.join()
self._detect_repo_async_thread = None
except Exception:
pass
# make sure model upload is done
if BackendModel.get_num_results() > 0 and wait_for_uploads:
BackendModel.wait_for_results()
@ -1086,7 +1098,7 @@ class Task(_Task):
@classmethod
def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type):
hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type)
hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
# check if we have a cached task_id we can reuse
# it must be from within the last 24h and with the same project/name/type
@ -1111,7 +1123,7 @@ class Task(_Task):
@classmethod
def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id):
hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type)
hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
task_id = str(task_id)
# update task session cache

View File

@ -36,7 +36,7 @@ class AsyncManagerMixin(object):
if r.ready():
continue
t = time.time()
r.wait(timeout=remaining)
r.wait(timeout=remaining or 2.0)
count += 1
if max_num_uploads is not None and max_num_uploads - count <= 0:
break