mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Fix default method might get updated after default argument was initialized
Fix default api method does not work when set in configuration
This commit is contained in:
		
							parent
							
								
									164169b73a
								
							
						
					
					
						commit
						443c6dc814
					
				| @ -8,12 +8,15 @@ ENV_FILES_HOST = EnvEntry("CLEARML_FILES_HOST", "TRAINS_FILES_HOST") | ||||
| ENV_ACCESS_KEY = EnvEntry("CLEARML_API_ACCESS_KEY", "TRAINS_API_ACCESS_KEY") | ||||
| ENV_SECRET_KEY = EnvEntry("CLEARML_API_SECRET_KEY", "TRAINS_API_SECRET_KEY") | ||||
| ENV_AUTH_TOKEN = EnvEntry("CLEARML_AUTH_TOKEN") | ||||
| ENV_VERBOSE = EnvEntry("CLEARML_API_VERBOSE", "TRAINS_API_VERBOSE", type=bool, default=False) | ||||
| ENV_VERBOSE = EnvEntry( | ||||
|     "CLEARML_API_VERBOSE", "TRAINS_API_VERBOSE", converter=safe_text_to_bool, type=bool, default=False | ||||
| ) | ||||
| ENV_HOST_VERIFY_CERT = EnvEntry("CLEARML_API_HOST_VERIFY_CERT", "TRAINS_API_HOST_VERIFY_CERT", | ||||
|                                 type=bool, default=True) | ||||
| ENV_OFFLINE_MODE = EnvEntry("CLEARML_OFFLINE_MODE", "TRAINS_OFFLINE_MODE", type=bool, converter=safe_text_to_bool) | ||||
| ENV_CLEARML_NO_DEFAULT_SERVER = EnvEntry("CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", | ||||
|                                          converter=safe_text_to_bool, type=bool, default=True) | ||||
| ENV_CLEARML_NO_DEFAULT_SERVER = EnvEntry( | ||||
|     "CLEARML_NO_DEFAULT_SERVER", "TRAINS_NO_DEFAULT_SERVER", converter=safe_text_to_bool, type=bool, default=True | ||||
| ) | ||||
| ENV_DISABLE_VAULT_SUPPORT = EnvEntry('CLEARML_DISABLE_VAULT_SUPPORT', type=bool) | ||||
| ENV_ENABLE_ENV_CONFIG_SECTION = EnvEntry('CLEARML_ENABLE_ENV_CONFIG_SECTION', type=bool) | ||||
| ENV_ENABLE_FILES_CONFIG_SECTION = EnvEntry('CLEARML_ENABLE_FILES_CONFIG_SECTION', type=bool) | ||||
|  | ||||
| @ -187,6 +187,7 @@ class Session(TokenManager): | ||||
|             "api.http.retries", ConfigTree()).as_plain_ordered_dict() | ||||
| 
 | ||||
|         http_retries_config["status_forcelist"] = self._get_retry_codes() | ||||
|         http_retries_config["config"] = self.config | ||||
|         self.__http_session = get_http_session_with_retry(**http_retries_config) | ||||
|         self.__http_session.write_timeout = self._write_session_timeout | ||||
|         self.__http_session.request_size_threshold = self._write_session_data_size | ||||
| @ -237,6 +238,18 @@ class Session(TokenManager): | ||||
| 
 | ||||
|         self._apply_config_sections(local_logger) | ||||
| 
 | ||||
|         self._update_default_api_method() | ||||
| 
 | ||||
|     def _update_default_api_method(self): | ||||
|         if not ENV_API_DEFAULT_REQ_METHOD.get(default=None) and self.config.get("api.http.default_method", None): | ||||
|             def_method = str(self.config.get("api.http.default_method", None)).strip() | ||||
|             if def_method.upper() not in ("GET", "POST", "PUT"): | ||||
|                 raise ValueError( | ||||
|                     "api.http.default_method variable must be 'get' or 'post' (any case is allowed)." | ||||
|                 ) | ||||
|             Request.def_method = def_method | ||||
|             Request._method = Request.def_method | ||||
| 
 | ||||
|     def _get_retry_codes(self): | ||||
|         # type: () -> List[int] | ||||
|         retry_codes = set(self._retry_codes) | ||||
| @ -278,7 +291,8 @@ class Session(TokenManager): | ||||
| 
 | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
|             res = self.send_request("users", "get_vaults", json={"enabled": True, "types": ["config"]}) | ||||
|             # Use params and not data/json otherwise payload might be dropped if we're using GET with a strict firewall | ||||
|             res = self.send_request("users", "get_vaults", params="enabled=true&types=config&types=config") | ||||
|             if res.ok: | ||||
|                 vaults = res.json().get("data", {}).get("vaults", []) | ||||
|                 data = list(filter(None, map(parse, vaults))) | ||||
| @ -312,12 +326,13 @@ class Session(TokenManager): | ||||
|         service, | ||||
|         action, | ||||
|         version=None, | ||||
|         method=Request.def_method, | ||||
|         method=None, | ||||
|         headers=None, | ||||
|         auth=None, | ||||
|         data=None, | ||||
|         json=None, | ||||
|         refresh_token_if_unauthorized=True, | ||||
|         params=None, | ||||
|     ): | ||||
|         """ Internal implementation for making a raw API request. | ||||
|             - Constructs the api endpoint name | ||||
| @ -331,6 +346,9 @@ class Session(TokenManager): | ||||
|         if self._offline_mode: | ||||
|             return None | ||||
| 
 | ||||
|         if not method: | ||||
|             method = Request.def_method | ||||
| 
 | ||||
|         res = None | ||||
|         host = self.host | ||||
|         headers = headers.copy() if headers else {} | ||||
| @ -401,11 +419,12 @@ class Session(TokenManager): | ||||
|         service, | ||||
|         action, | ||||
|         version=None, | ||||
|         method=Request.def_method, | ||||
|         method=None, | ||||
|         headers=None, | ||||
|         data=None, | ||||
|         json=None, | ||||
|         async_enable=False, | ||||
|         params=None, | ||||
|     ): | ||||
|         """ | ||||
|         Send a raw API request. | ||||
| @ -420,6 +439,8 @@ class Session(TokenManager): | ||||
|         :param async_enable: whether request is asynchronous | ||||
|         :return: requests Response instance | ||||
|         """ | ||||
|         if not method: | ||||
|             method = Request.def_method | ||||
|         headers = self.add_auth_headers( | ||||
|             headers.copy() if headers else {} | ||||
|         ) | ||||
| @ -434,6 +455,7 @@ class Session(TokenManager): | ||||
|             headers=headers, | ||||
|             data=data, | ||||
|             json=json, | ||||
|             params=params, | ||||
|         ) | ||||
| 
 | ||||
|     def send_request_batch( | ||||
| @ -444,7 +466,7 @@ class Session(TokenManager): | ||||
|         headers=None, | ||||
|         data=None, | ||||
|         json=None, | ||||
|         method=Request.def_method, | ||||
|         method=None, | ||||
|     ): | ||||
|         """ | ||||
|         Send a raw batch API request. Batch requests always use application/json-lines content type. | ||||
| @ -469,6 +491,9 @@ class Session(TokenManager): | ||||
|             # Missing data (data or json), batch requests are meaningless without it. | ||||
|             return None | ||||
| 
 | ||||
|         if not method: | ||||
|             method = Request.def_method | ||||
| 
 | ||||
|         headers = headers.copy() if headers else {} | ||||
|         headers["Content-Type"] = "application/json-lines" | ||||
| 
 | ||||
| @ -677,7 +702,7 @@ class Session(TokenManager): | ||||
|                             pass | ||||
|                     cls.max_api_version = cls.api_version = cls._offline_default_version | ||||
|             else: | ||||
|                 # if the requested version is lower then the minium we support, | ||||
|                 # if the requested version is lower then the minimum we support, | ||||
|                 # no need to actually check what the server has, we assume it must have at least our version. | ||||
|                 if cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version)): | ||||
|                    return True | ||||
| @ -736,15 +761,14 @@ class Session(TokenManager): | ||||
|         auth = HTTPBasicAuth(self.access_key, self.secret_key) if self.access_key and self.secret_key else None | ||||
|         res = None | ||||
|         try: | ||||
|             data = {"expiration_sec": exp} if exp else {} | ||||
|             res = self._send_request( | ||||
|                 method=Request.def_method, | ||||
|                 service="auth", | ||||
|                 action="login", | ||||
|                 auth=auth, | ||||
|                 json=data, | ||||
|                 headers=headers, | ||||
|                 refresh_token_if_unauthorized=False, | ||||
|                 params={"expiration_sec": exp} if exp else {}, | ||||
|             ) | ||||
|             try: | ||||
|                 resp = res.json() | ||||
|  | ||||
| @ -95,7 +95,9 @@ def get_http_session_with_retry( | ||||
|         backoff_factor=0, | ||||
|         backoff_max=None, | ||||
|         pool_connections=None, | ||||
|         pool_maxsize=None): | ||||
|         pool_maxsize=None, | ||||
|         config=None | ||||
| ): | ||||
|     global __disable_certificate_verification_warning | ||||
|     if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)): | ||||
|         raise ValueError('Bad configuration. All retry count values must be null or int') | ||||
| @ -103,16 +105,18 @@ def get_http_session_with_retry( | ||||
|     if status_forcelist and not all(isinstance(x, int) for x in status_forcelist): | ||||
|         raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints') | ||||
| 
 | ||||
|     config = config or get_config() | ||||
| 
 | ||||
|     pool_maxsize = ( | ||||
|         pool_maxsize | ||||
|         if pool_maxsize is not None | ||||
|         else get_config().get('api.http.pool_maxsize', 512) | ||||
|         else config.get('api.http.pool_maxsize', 512) | ||||
|     ) | ||||
| 
 | ||||
|     pool_connections = ( | ||||
|         pool_connections | ||||
|         if pool_connections is not None | ||||
|         else get_config().get('api.http.pool_connections', 512) | ||||
|         else config.get('api.http.pool_connections', 512) | ||||
|     ) | ||||
| 
 | ||||
|     session = SessionWithTimeout() | ||||
| @ -135,7 +139,7 @@ def get_http_session_with_retry( | ||||
|     session.mount('http://', adapter) | ||||
|     session.mount('https://', adapter) | ||||
|     # update verify host certificate | ||||
|     session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True)) | ||||
|     session.verify = ENV_HOST_VERIFY_CERT.get(default=config.get('api.verify_certificate', True)) | ||||
|     if not session.verify and __disable_certificate_verification_warning < 2: | ||||
|         # show warning | ||||
|         __disable_certificate_verification_warning += 1 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai