Add api.verify_certificate to allow non-verified ssl connection (enterprise firewall mitm scenarios); by default, only secured connections are allowed.

Add latest package check, as long as we are in pre-release.
This commit is contained in:
allegroai 2019-06-19 15:14:28 +03:00
parent fdc767c1c5
commit 675dc32528
9 changed files with 94 additions and 23 deletions

View File

@ -118,7 +118,7 @@ taks = Task.init(project_name, task_name, output_uri="gs://bucket-name/folder")
``` ```
**NOTE:** These require configuring the storage credentials in `~/trains.conf`. **NOTE:** These require configuring the storage credentials in `~/trains.conf`.
For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L51). For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L55).
## I am training multiple models at the same time, but I only see one of them. What happened? <a name="only-last-model-appears"></a> ## I am training multiple models at the same time, but I only see one of them. What happened? <a name="only-last-model-appears"></a>

View File

@ -2,8 +2,12 @@
api { api {
# Notice: 'host' is the api server (default port 8008), not the web server. # Notice: 'host' is the api server (default port 8008), not the web server.
host: http://localhost:8008 host: http://localhost:8008
# Credentials are generated in the webapp, http://localhost:8080/admin # Credentials are generated in the webapp, http://localhost:8080/admin
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"} credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
} }
sdk { sdk {
# TRAINS - default SDK configuration # TRAINS - default SDK configuration

View File

@ -17,16 +17,16 @@ jsonschema>=2.6.0
numpy>=1.10 numpy>=1.10
opencv-python>=3.2.0.8 opencv-python>=3.2.0.8
pathlib2>=2.3.0 pathlib2>=2.3.0
plotly>=3.9.0
psutil>=3.4.2 psutil>=3.4.2
pyhocon>=0.3.38 pyhocon>=0.3.38
python-dateutil>=2.6.1 python-dateutil>=2.6.1
pyjwt>=1.6.4
PyYAML>=3.12 PyYAML>=3.12
requests-file>=1.4.2 requests-file>=1.4.2
requests>=2.18.4 requests>=2.18.4
six>=1.11.0 six>=1.11.0
tqdm>=4.19.5 tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.22 urllib3>=1.22
watchdog>=0.8.0 watchdog>=0.8.0
pyjwt>=1.6.4
plotly>=3.9.0
typing>=3.6.4

View File

@ -2,6 +2,9 @@
version: 1.5 version: 1.5
host: https://demoapi.trainsai.io host: https://demoapi.trainsai.io
# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True
# default demoapi.trainsai.io credentials # default demoapi.trainsai.io credentials
credentials { credentials {
access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"

View File

@ -5,3 +5,4 @@ ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY") ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY") ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False) ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)

View File

@ -1,3 +1,4 @@
import logging
import ssl import ssl
import sys import sys
@ -8,6 +9,8 @@ from urllib3.util import Retry
from urllib3 import PoolManager from urllib3 import PoolManager
import six import six
from .session.defs import ENV_HOST_VERIFY_CERT
if six.PY3: if six.PY3:
from functools import lru_cache from functools import lru_cache
elif six.PY2: elif six.PY2:
@ -15,12 +18,13 @@ elif six.PY2:
from backports.functools_lru_cache import lru_cache from backports.functools_lru_cache import lru_cache
__disable_certificate_verification_warning = 0
@lru_cache() @lru_cache()
def get_config(): def get_config():
from ..backend_config import Config from ..config import config_obj
config = Config(verbose=False) return config_obj
config.reload()
return config
class TLSv1HTTPAdapter(HTTPAdapter): class TLSv1HTTPAdapter(HTTPAdapter):
@ -42,6 +46,7 @@ def get_http_session_with_retry(
backoff_max=None, backoff_max=None,
pool_connections=None, pool_connections=None,
pool_maxsize=None): pool_maxsize=None):
global __disable_certificate_verification_warning
if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)): if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
raise ValueError('Bad configuration. All retry count values must be null or int') raise ValueError('Bad configuration. All retry count values must be null or int')
@ -72,6 +77,22 @@ def get_http_session_with_retry(
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize) adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
session.mount('http://', adapter) session.mount('http://', adapter)
session.mount('https://', adapter) session.mount('https://', adapter)
# update verify host certificate
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
if not session.verify and __disable_certificate_verification_warning < 2:
# show warning
__disable_certificate_verification_warning += 1
logging.getLogger('TRAINS').warning(
msg='InsecureRequestWarning: Certificate verification is disabled! Adding '
'certificate verification is strongly advised. See: '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings')
# suppress urllib3's InsecureRequestWarning so only our single warning above is shown
import urllib3
# noinspection PyBroadException
try:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
except Exception:
pass
return session return session

View File

@ -2,11 +2,13 @@
import collections import collections
import itertools import itertools
import logging import logging
from threading import RLock, Thread
from copy import copy from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse from six.moves.urllib.parse import urlparse, urlunparse
import six import six
from ...backend_api.session.defs import ENV_HOST
from ...backend_interface.task.development.worker import DevWorker from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects from ...backend_api.services import tasks, models, events, projects
@ -83,6 +85,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._parameters_allowed_types = ( self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None)) six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
) )
self._edit_lock = RLock()
if not task_id: if not task_id:
# generate a new task # generate a new task
@ -172,6 +175,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return task_id return task_id
def _update_repository(self): def _update_repository(self):
def check_package_update():
# check latest version
from ...utilities.check_updates import CheckPackageUpdates
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
self.get_logger().console(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
)
else:
self.get_logger().console(
'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
latest_version[0]),
)
check_package_update_thread = Thread(target=check_package_update)
check_package_update_thread.start()
result = ScriptInfo.get(log=self.log) result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages: for msg in result.warning_messages:
self.get_logger().console(msg) self.get_logger().console(msg)
@ -180,6 +201,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we might run asynchronously, don't use self.data (lest someone else # Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit) # overwrite it before we have a chance to call edit)
self._edit(script=result.script) self._edit(script=result.script)
self.reload()
check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training): def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s') created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
@ -335,8 +358,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _reload(self): def _reload(self):
""" Reload the task object from the backend """ """ Reload the task object from the backend """
res = self.send(tasks.GetByIdRequest(task=self.id)) with self._edit_lock:
return res.response.task res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task
def reset(self, set_started_on_success=True): def reset(self, set_started_on_success=True):
""" Reset the task. Task will be reloaded following a successful reset. """ """ Reset the task. Task will be reloaded following a successful reset. """
@ -645,8 +669,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parsed = parsed._replace(netloc=parsed.netloc+':8081') parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed) return urlunparse(parsed)
def _get_app_server(self): @classmethod
host = config_obj.get('api.host') def _get_api_server(cls):
return ENV_HOST.get(default=config_obj.get("api.host"))
@classmethod
def _get_app_server(cls):
host = cls._get_api_server()
if '://demoapi.' in host: if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.') return host.replace('://demoapi.', '://demoapp.')
if '://api.' in host: if '://api.' in host:
@ -657,12 +686,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return host.replace(':8008', ':8080') return host.replace(':8008', ':8080')
def _edit(self, **kwargs): def _edit(self, **kwargs):
# Since we are using forced update, make sure the task status is valid with self._edit_lock:
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)): # Since we are using forced update, make sure the task status is valid
raise ValueError('Task object can only be updated if created or in_progress') if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res return res
@classmethod @classmethod
def create_new_task(cls, session, task_entry, log=None): def create_new_task(cls, session, task_entry, log=None):

View File

@ -9,7 +9,6 @@ from collections import OrderedDict, Callable
import psutil import psutil
import six import six
from six.moves._thread import start_new_thread
from .backend_api.services import tasks, projects from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum from .backend_interface import TaskStatusEnum
@ -100,6 +99,8 @@ class Task(_Task):
self._dev_stop_signal = None self._dev_stop_signal = None
self._dev_mode_periodic_flag = False self._dev_mode_periodic_flag = False
self._connected_parameter_type = None self._connected_parameter_type = None
self._detect_repo_async_thread = None
self._lock = threading.RLock()
# register atexit, so that we mark the task as stopped # register atexit, so that we mark the task as stopped
self._at_exit_called = False self._at_exit_called = False
self.__register_at_exit(self._at_exit) self.__register_at_exit(self._at_exit)
@ -391,11 +392,12 @@ class Task(_Task):
# update current repository and put warning into logs # update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async: if in_dev_mode and cls.__detect_repo_async:
start_new_thread(task._update_repository, tuple()) task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.start()
else: else:
task._update_repository() task._update_repository()
# show the debug metrics page in the log, it is very convinient # show the debug metrics page in the log, it is very convenient
logger.console( logger.console(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format( 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(), task._get_app_server(),
@ -543,6 +545,16 @@ class Task(_Task):
""" """
self._dev_mode_periodic() self._dev_mode_periodic()
# wait for detection repo sync
if self._detect_repo_async_thread:
with self._lock:
if self._detect_repo_async_thread:
try:
self._detect_repo_async_thread.join()
self._detect_repo_async_thread = None
except Exception:
pass
# make sure model upload is done # make sure model upload is done
if BackendModel.get_num_results() > 0 and wait_for_uploads: if BackendModel.get_num_results() > 0 and wait_for_uploads:
BackendModel.wait_for_results() BackendModel.wait_for_results()
@ -1086,7 +1098,7 @@ class Task(_Task):
@classmethod @classmethod
def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type): def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type):
hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type) hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
# check if we have a cached task_id we can reuse # check if we have a cached task_id we can reuse
# it must be from within the last 24h and with the same project/name/type # it must be from within the last 24h and with the same project/name/type
@ -1111,7 +1123,7 @@ class Task(_Task):
@classmethod @classmethod
def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id): def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id):
hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type) hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
task_id = str(task_id) task_id = str(task_id)
# update task session cache # update task session cache

View File

@ -36,7 +36,7 @@ class AsyncManagerMixin(object):
if r.ready(): if r.ready():
continue continue
t = time.time() t = time.time()
r.wait(timeout=remaining) r.wait(timeout=remaining or 2.0)
count += 1 count += 1
if max_num_uploads is not None and max_num_uploads - count <= 0: if max_num_uploads is not None and max_num_uploads - count <= 0:
break break