Mirror of https://github.com/clearml/clearml
Add `api.verify_certificate` to allow non-verified SSL connections (enterprise firewall MITM scenarios); by default, only secured connections are allowed.
Add a latest-package check, for as long as we are in pre-release.
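As a rough illustration of the precedence the new setting follows in the diff below (environment variable first, then `api.verify_certificate` from `~/trains.conf`, defaulting to verification enabled), here is a minimal sketch; it uses plain `os.environ` instead of the SDK's `EnvEntry` helper, and the boolean parsing rules are an assumption:

```python
import os

def resolve_verify_certificate(config_value=True):
    """Effective SSL-verification flag for the API session (illustrative sketch, not the SDK code)."""
    env_value = os.environ.get("TRAINS_API_HOST_VERIFY_CERT")
    if env_value is not None:
        # assumed parsing: any of these strings disables verification
        return env_value.strip().lower() not in ("0", "false", "no", "off")
    # otherwise fall back to api.verify_certificate from ~/trains.conf (True by default)
    return bool(config_value)

# Behind an enterprise MITM firewall one might set TRAINS_API_HOST_VERIFY_CERT=0,
# or add `verify_certificate: False` inside the api { } section of ~/trains.conf.
```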
This commit is contained in:
parent fdc767c1c5
commit 675dc32528
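The second change, the pre-release package check, asks whether a newer `trains` release exists and prints an upgrade recommendation. The repository does this through `CheckPackageUpdates.check_new_package_available()` (see the `_update_repository` hunk further down); the sketch below is only an assumed stand-in that queries the public PyPI JSON API, with the package name and current version hard-coded for illustration:

```python
import requests
from pkg_resources import parse_version

def check_new_package_available(package="trains", current="0.9.0", timeout=3.0):
    """Return the newest PyPI version string if it is newer than `current`, else None (sketch)."""
    try:
        info = requests.get("https://pypi.org/pypi/{}/json".format(package), timeout=timeout).json()
        latest = info["info"]["version"]
    except Exception:
        # a version check should never break the experiment
        return None
    return latest if parse_version(latest) > parse_version(current) else None
```

In the actual commit the check runs on a background `Thread` started from `Task._update_repository()`, so it never blocks task initialization.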
@@ -118,7 +118,7 @@ taks = Task.init(project_name, task_name, output_uri="gs://bucket-name/folder")
```

**NOTE:** These require configuring the storage credentials in `~/trains.conf`.
For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L51).
For a more detailed example, see [here](https://github.com/allegroai/trains/blob/master/docs/trains.conf#L55).

## I am training multiple models at the same time, but I only see one of them. What happened? <a name="only-last-model-appears"></a>
@@ -2,8 +2,12 @@
api {
    # Notice: 'host' is the api server (default port 8008), not the web server.
    host: http://localhost:8008

    # Credentials are generated in the webapp, http://localhost:8080/admin
    credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}

    # verify host ssl certificate, set to False only if you have a very good reason
    verify_certificate: True
}
sdk {
    # TRAINS - default SDK configuration
@@ -17,16 +17,16 @@ jsonschema>=2.6.0
numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0
plotly>=3.9.0
psutil>=3.4.2
pyhocon>=0.3.38
python-dateutil>=2.6.1
pyjwt>=1.6.4
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.18.4
six>=1.11.0
tqdm>=4.19.5
typing>=3.6.4
urllib3>=1.22
watchdog>=0.8.0
pyjwt>=1.6.4
plotly>=3.9.0
typing>=3.6.4
@@ -2,6 +2,9 @@
version: 1.5
host: https://demoapi.trainsai.io

# verify host ssl certificate, set to False only if you have a very good reason
verify_certificate: True

# default demoapi.trainsai.io credentials
credentials {
    access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
@@ -5,3 +5,4 @@ ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)
@@ -1,3 +1,4 @@
import logging
import ssl
import sys

@@ -8,6 +9,8 @@ from urllib3.util import Retry
from urllib3 import PoolManager
import six

from .session.defs import ENV_HOST_VERIFY_CERT

if six.PY3:
    from functools import lru_cache
elif six.PY2:

@@ -15,12 +18,13 @@ elif six.PY2:
    from backports.functools_lru_cache import lru_cache


__disable_certificate_verification_warning = 0


@lru_cache()
def get_config():
    from ..backend_config import Config
    config = Config(verbose=False)
    config.reload()
    return config
    from ..config import config_obj
    return config_obj


class TLSv1HTTPAdapter(HTTPAdapter):

@@ -42,6 +46,7 @@ def get_http_session_with_retry(
        backoff_max=None,
        pool_connections=None,
        pool_maxsize=None):
    global __disable_certificate_verification_warning
    if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
        raise ValueError('Bad configuration. All retry count values must be null or int')

@@ -72,6 +77,22 @@ def get_http_session_with_retry(
    adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    # update verify host certiface
    session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
    if not session.verify and __disable_certificate_verification_warning < 2:
        # show warning
        __disable_certificate_verification_warning += 1
        logging.getLogger('TRAINS').warning(
            msg='InsecureRequestWarning: Certificate verification is disabled! Adding '
                'certificate verification is strongly advised. See: '
                'https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings')
        # make sure we only do not see the warning
        import urllib3
        # noinspection PyBroadException
        try:
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except Exception:
            pass
    return session
@@ -2,11 +2,13 @@
import collections
import itertools
import logging
from threading import RLock, Thread
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse

import six

from ...backend_api.session.defs import ENV_HOST
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects

@@ -83,6 +85,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        self._parameters_allowed_types = (
            six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
        )
        self._edit_lock = RLock()

        if not task_id:
            # generate a new task

@@ -172,6 +175,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        return task_id

    def _update_repository(self):
        def check_package_update():
            # check latest version
            from ...utilities.check_updates import CheckPackageUpdates
            latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
            if latest_version:
                if not latest_version[1]:
                    self.get_logger().console(
                        'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
                            latest_version[0]),
                    )
                else:
                    self.get_logger().console(
                        'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
                            latest_version[0]),
                    )

        check_package_update_thread = Thread(target=check_package_update)
        check_package_update_thread.start()
        result = ScriptInfo.get(log=self.log)
        for msg in result.warning_messages:
            self.get_logger().console(msg)

@@ -180,6 +201,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        # Since we might run asynchronously, don't use self.data (lest someone else
        # overwrite it before we have a chance to call edit)
        self._edit(script=result.script)
        self.reload()
        check_package_update_thread.join()

    def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
        created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')

@@ -335,8 +358,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

    def _reload(self):
        """ Reload the task object from the backend """
        res = self.send(tasks.GetByIdRequest(task=self.id))
        return res.response.task
        with self._edit_lock:
            res = self.send(tasks.GetByIdRequest(task=self.id))
            return res.response.task

    def reset(self, set_started_on_success=True):
        """ Reset the task. Task will be reloaded following a successful reset. """

@@ -645,8 +669,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        parsed = parsed._replace(netloc=parsed.netloc+':8081')
        return urlunparse(parsed)

    def _get_app_server(self):
        host = config_obj.get('api.host')
    @classmethod
    def _get_api_server(cls):
        return ENV_HOST.get(default=config_obj.get("api.host"))

    @classmethod
    def _get_app_server(cls):
        host = cls._get_api_server()
        if '://demoapi.' in host:
            return host.replace('://demoapi.', '://demoapp.')
        if '://api.' in host:

@@ -657,12 +686,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        return host.replace(':8008', ':8080')

    def _edit(self, **kwargs):
        # Since we ae using forced update, make sure he task status is valid
        if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
            raise ValueError('Task object can only be updated if created or in_progress')
        with self._edit_lock:
            # Since we ae using forced update, make sure he task status is valid
            if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
                raise ValueError('Task object can only be updated if created or in_progress')

        res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
        return res
            res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
            return res

    @classmethod
    def create_new_task(cls, session, task_entry, log=None):
@@ -9,7 +9,6 @@ from collections import OrderedDict, Callable

import psutil
import six
from six.moves._thread import start_new_thread

from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum

@@ -100,6 +99,8 @@ class Task(_Task):
        self._dev_stop_signal = None
        self._dev_mode_periodic_flag = False
        self._connected_parameter_type = None
        self._detect_repo_async_thread = None
        self._lock = threading.RLock()
        # register atexit, so that we mark the task as stopped
        self._at_exit_called = False
        self.__register_at_exit(self._at_exit)

@@ -391,11 +392,12 @@ class Task(_Task):

        # update current repository and put warning into logs
        if in_dev_mode and cls.__detect_repo_async:
            start_new_thread(task._update_repository, tuple())
            task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
            task._detect_repo_async_thread.start()
        else:
            task._update_repository()

        # show the debug metrics page in the log, it is very convinient
        # show the debug metrics page in the log, it is very convenient
        logger.console(
            'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
                task._get_app_server(),

@@ -543,6 +545,16 @@ class Task(_Task):
        """
        self._dev_mode_periodic()

        # wait for detection repo sync
        if self._detect_repo_async_thread:
            with self._lock:
                if self._detect_repo_async_thread:
                    try:
                        self._detect_repo_async_thread.join()
                        self._detect_repo_async_thread = None
                    except Exception:
                        pass

        # make sure model upload is done
        if BackendModel.get_num_results() > 0 and wait_for_uploads:
            BackendModel.wait_for_results()

@@ -1086,7 +1098,7 @@ class Task(_Task):

    @classmethod
    def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type):
        hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type)
        hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)

        # check if we have a cached task_id we can reuse
        # it must be from within the last 24h and with the same project/name/type

@@ -1111,7 +1123,7 @@ class Task(_Task):

    @classmethod
    def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id):
        hash_key = cls.__get_hash_key(default_project_name, default_task_name, default_task_type)
        hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)

        task_id = str(task_id)
        # update task session cache
@@ -36,7 +36,7 @@ class AsyncManagerMixin(object):
            if r.ready():
                continue
            t = time.time()
            r.wait(timeout=remaining)
            r.wait(timeout=remaining or 2.0)
            count += 1
            if max_num_uploads is not None and max_num_uploads - count <= 0:
                break