mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add Task.set_credentials for cloud hosted jupyter support
This commit is contained in:
		
							parent
							
								
									cac4ac12b8
								
							
						
					
					
						commit
						7d0bf4838e
					
				| @ -1,14 +1,15 @@ | ||||
| { | ||||
|     version: 1.5 | ||||
|     host: https://demoapi.trainsai.io | ||||
|     # default https://demoapi.trainsai.io host | ||||
|     host: "" | ||||
| 
 | ||||
|     # verify host ssl certificate, set to False only if you have a very good reason | ||||
|     verify_certificate: True | ||||
| 
 | ||||
|     # default demoapi.trainsai.io credentials | ||||
|     credentials { | ||||
|         access_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" | ||||
|         secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" | ||||
|         access_key: "" | ||||
|         secret_key: "" | ||||
|     } | ||||
| 
 | ||||
|     # default version assigned to requests with no specific version. this is not expected to change | ||||
|  | ||||
| @ -3,9 +3,9 @@ import sys | ||||
| import types | ||||
| from socket import gethostname | ||||
| 
 | ||||
| import jwt | ||||
| import requests | ||||
| import six | ||||
| import jwt | ||||
| from pyhocon import ConfigTree | ||||
| from requests.auth import HTTPBasicAuth | ||||
| 
 | ||||
| @ -36,6 +36,9 @@ class Session(TokenManager): | ||||
|     _session_timeout = (5.0, None) | ||||
| 
 | ||||
|     api_version = '2.1' | ||||
|     default_host = "https://demoapi.trainsai.io" | ||||
|     default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" | ||||
|     default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" | ||||
| 
 | ||||
|     # TODO: add requests.codes.gateway_timeout once we support async commits | ||||
|     _retry_codes = [ | ||||
| @ -94,7 +97,7 @@ class Session(TokenManager): | ||||
|         self._logger = logger | ||||
| 
 | ||||
|         self.__access_key = api_key or ENV_ACCESS_KEY.get( | ||||
|             default=self.config.get("api.credentials.access_key", None) | ||||
|             default=(self.config.get("api.credentials.access_key") or self.default_key) | ||||
|         ) | ||||
|         if not self.access_key: | ||||
|             raise ValueError( | ||||
| @ -102,14 +105,14 @@ class Session(TokenManager): | ||||
|             ) | ||||
| 
 | ||||
|         self.__secret_key = secret_key or ENV_SECRET_KEY.get( | ||||
|             default=self.config.get("api.credentials.secret_key", None) | ||||
|             default=(self.config.get("api.credentials.secret_key") or self.default_secret) | ||||
|         ) | ||||
|         if not self.secret_key: | ||||
|             raise ValueError( | ||||
|                 "Missing secret_key. Please set in configuration file or pass in session init." | ||||
|             ) | ||||
| 
 | ||||
|         host = host or ENV_HOST.get(default=self.config.get("api.host")) | ||||
|         host = host or self.get_api_server_host(config=self.config) | ||||
|         if not host: | ||||
|             raise ValueError("host is required in init or config") | ||||
| 
 | ||||
| @ -386,6 +389,13 @@ class Session(TokenManager): | ||||
| 
 | ||||
|         return call_result | ||||
| 
 | ||||
|     @classmethod | ||||
|     def get_api_server_host(cls, config=None): | ||||
|         if not config: | ||||
|             from ...config import config_obj | ||||
|             config = config_obj | ||||
|         return ENV_HOST.get(default=(config.get("api.host") or cls.default_host)) | ||||
| 
 | ||||
|     def _do_refresh_token(self, old_token, exp=None): | ||||
|         """ TokenManager abstract method implementation. | ||||
|             Here we ignore the old token and simply obtain a new token. | ||||
|  | ||||
| @ -4,9 +4,10 @@ import requests.exceptions | ||||
| import six | ||||
| from ..backend_api import Session | ||||
| from ..backend_api.session import BatchRequest | ||||
| from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY | ||||
| 
 | ||||
| from ..config import config_obj | ||||
| from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY | ||||
| from ..config.defs import LOG_LEVEL_ENV_VAR | ||||
| from ..debugging import get_logger | ||||
| from ..backend_api.version import __version__ | ||||
| from .session import SendError, SessionInterface | ||||
| @ -78,8 +79,8 @@ class InterfaceBase(SessionInterface): | ||||
|                 initialize_logging=False, | ||||
|                 client='sdk-%s' % __version__, | ||||
|                 config=config_obj, | ||||
|                 api_key=API_ACCESS_KEY.get(), | ||||
|                 secret_key=API_SECRET_KEY.get(), | ||||
|                 api_key=ENV_ACCESS_KEY.get(), | ||||
|                 secret_key=ENV_SECRET_KEY.get(), | ||||
|             ) | ||||
|         return InterfaceBase._default_session | ||||
| 
 | ||||
|  | ||||
| @ -9,7 +9,6 @@ from six.moves.urllib.parse import urlparse, urlunparse | ||||
| 
 | ||||
| import six | ||||
| 
 | ||||
| from ...backend_api.session.defs import ENV_HOST | ||||
| from ...backend_interface.task.development.worker import DevWorker | ||||
| from ...backend_api import Session | ||||
| from ...backend_api.services import tasks, models, events, projects | ||||
| @ -23,7 +22,7 @@ from ..setupuploadmixin import SetupUploadMixin | ||||
| from ..util import make_message, get_or_create_project, get_single_result, \ | ||||
|     exact_match_regex | ||||
| from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \ | ||||
|     running_remotely, get_cache_dir, config_obj | ||||
|     running_remotely, get_cache_dir | ||||
| from ...debugging import get_logger | ||||
| from ...debugging.log import LoggerRoot | ||||
| from ...storage import StorageHelper | ||||
| @ -205,8 +204,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|         # overwrite it before we have a chance to call edit) | ||||
|         self._edit(script=result.script) | ||||
|         self.reload() | ||||
|         if result.script.get('requirements'): | ||||
|             self._update_requirements(result.script.get('requirements')) | ||||
|         self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '') | ||||
|         check_package_update_thread.join() | ||||
| 
 | ||||
|     def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training): | ||||
| @ -673,28 +671,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|         app_host = self._get_app_server() | ||||
|         parsed = urlparse(app_host) | ||||
|         if parsed.port: | ||||
|             parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081')) | ||||
|             parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081', 1)) | ||||
|         elif parsed.netloc.startswith('demoapp.'): | ||||
|             parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.')) | ||||
|             parsed = parsed._replace(netloc=parsed.netloc.replace('demoapp.', 'demofiles.', 1)) | ||||
|         elif parsed.netloc.startswith('app.'): | ||||
|             parsed = parsed._replace(netloc=parsed.netloc.replace('app.', 'files.', 1)) | ||||
|         else: | ||||
|             parsed = parsed._replace(netloc=parsed.netloc+':8081') | ||||
|         return urlunparse(parsed) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_api_server(cls): | ||||
|         return ENV_HOST.get(default=config_obj.get("api.host")) | ||||
|         return Session.get_api_server_host() | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _get_app_server(cls): | ||||
|         host = cls._get_api_server() | ||||
|         if '://demoapi.' in host: | ||||
|             return host.replace('://demoapi.', '://demoapp.') | ||||
|             return host.replace('://demoapi.', '://demoapp.', 1) | ||||
|         if '://api.' in host: | ||||
|             return host.replace('://api.', '://app.') | ||||
|             return host.replace('://api.', '://app.', 1) | ||||
| 
 | ||||
|         parsed = urlparse(host) | ||||
|         if parsed.port == 8008: | ||||
|             return host.replace(':8008', ':8080') | ||||
|             return host.replace(':8008', ':8080', 1) | ||||
| 
 | ||||
|     def _edit(self, **kwargs): | ||||
|         with self._edit_lock: | ||||
| @ -709,8 +709,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|     def _update_requirements(self, requirements): | ||||
|         if not isinstance(requirements, dict): | ||||
|             requirements = {'pip': requirements} | ||||
|         self.data.script.requirements = requirements | ||||
|         self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) | ||||
|         # protection, Old API might not support it | ||||
|         try: | ||||
|             self.data.script.requirements = requirements | ||||
|             self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) | ||||
|         except Exception: | ||||
|             pass | ||||
| 
 | ||||
|     def _update_script(self, script): | ||||
|         self.data.script = script | ||||
|  | ||||
| @ -57,29 +57,29 @@ def main(): | ||||
|     if parsed_host.port == 8080: | ||||
|         # this is a docker 8080 is the web address, we need the api address, it is 8008 | ||||
|         print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008') | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|     elif parsed_host.netloc.startswith('demoapp.'): | ||||
|         print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format( | ||||
|             parsed_host.netloc)) | ||||
|         # this is our demo server | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.', 1) + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|     elif parsed_host.netloc.startswith('app.'): | ||||
|         print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format( | ||||
|             parsed_host.netloc)) | ||||
|         # this is our application server | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.', 1) + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|     elif parsed_host.port == 8008: | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080', 1) + parsed_host.path | ||||
|     elif parsed_host.netloc.startswith('demoapi.'): | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.', 1) + parsed_host.path | ||||
|     elif parsed_host.netloc.startswith('api.'): | ||||
|         api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path | ||||
|         web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.', 1) + parsed_host.path | ||||
|     else: | ||||
|         api_host = None | ||||
|         web_host = None | ||||
|  | ||||
| @ -11,6 +11,7 @@ import psutil | ||||
| import six | ||||
| 
 | ||||
| from .backend_api.services import tasks, projects | ||||
| from .backend_api.session.session import Session | ||||
| from .backend_interface.model import Model as BackendModel | ||||
| from .backend_interface.task import Task as _Task | ||||
| from .backend_interface.task.args import _Arguments | ||||
| @ -25,6 +26,7 @@ from .errors import UsageError | ||||
| from .logger import Logger | ||||
| from .model import InputModel, OutputModel, ARCHIVED_TAG | ||||
| from .task_parameters import TaskParameters | ||||
| from .binding.environ_bind import EnvironmentBind | ||||
| from .binding.absl_bind import PatchAbsl | ||||
| from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ | ||||
|     argparser_update_currenttask | ||||
| @ -212,6 +214,7 @@ class Task(_Task): | ||||
|             Task.__main_task = task | ||||
|             # Patch argparse to be aware of the current task | ||||
|             argparser_update_currenttask(Task.__main_task) | ||||
|             EnvironmentBind.update_current_task(Task.__main_task) | ||||
|             if auto_connect_frameworks: | ||||
|                 PatchedMatplotlib.update_current_task(Task.__main_task) | ||||
|                 PatchAbsl.update_current_task(Task.__main_task) | ||||
| @ -687,6 +690,26 @@ class Task(_Task): | ||||
|         self.data.last_iteration = int(last_iteration) | ||||
|         self._edit(last_iteration=self.data.last_iteration) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def set_credentials(cls, host=None, key=None, secret=None): | ||||
|         """ | ||||
|         Set new default TRAINS-server host and credentials | ||||
|         These configurations will be overridden by wither OS environment variables or trains.conf configuration file | ||||
|         Notice: credentials needs to be set prior to Task initialization | ||||
|         :param host: host url, example: host='http://localhost:8008' | ||||
|         :type  host: str | ||||
|         :param key: user key/secret pair, example: key='thisisakey123' | ||||
|         :type  key: str | ||||
|         :param secret: user key/secret pair, example: secret='thisisseceret123' | ||||
|         :type  secret: str | ||||
|         """ | ||||
|         if host: | ||||
|             Session.default_host = host | ||||
|         if key: | ||||
|             Session.default_key = key | ||||
|         if secret: | ||||
|             Session.default_secret = secret | ||||
| 
 | ||||
|     def _connect_output_model(self, model): | ||||
|         assert isinstance(model, OutputModel) | ||||
|         model.connect(self) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai