From 036d7b6ef2ab3f6fb00746861f4dc5b11ef24750 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Sun, 17 Nov 2024 10:51:18 +0200 Subject: [PATCH] Add TCP protocol support to Task.request_external_endpoint() --- clearml/task.py | 162 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 121 insertions(+), 41 deletions(-) diff --git a/clearml/task.py b/clearml/task.py index fede03d5..c4a87e18 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -185,6 +185,11 @@ class Task(_Task): _launch_multi_node_section = "launch_multi_node" _launch_multi_node_instance_tag = "multi_node_instance" + _external_endpoint_port_map = {"http": "_PORT", "tcp": "external_tcp_port"} + _external_endpoint_address_map = {"http": "_ADDRESS", "tcp": "external_address"} + _external_endpoint_service_map = {"http": "EXTERNAL", "tcp": "EXTERNAL_TCP"} + _external_endpoint_internal_port_map = {"http": "_PORT", "tcp": "upstream_task_port"} + class _ConnectedParametersType(object): argparse = "argument_parser" dictionary = "dictionary" @@ -218,6 +223,7 @@ class Task(_Task): self._resource_monitor = None self._calling_filename = None self._remote_functions_generated = {} + self._external_endpoint_ports = {} # register atexit, so that we mark the task as stopped self._at_exit_called = False @@ -850,7 +856,7 @@ class Task(_Task): Request an external endpoint for an application :param port: Port the application is listening to - :param protocol: As of now, only `http` is supported + :param protocol: `http` or `tcp` :param wait: If True, wait for the endpoint to be assigned :param wait_interval_seconds: The poll frequency when waiting for the endpoint :param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint, @@ -865,67 +871,128 @@ class Task(_Task): - protocol - the protocol used by the endpoint """ Session.verify_feature_set("advanced") - if not getattr(self, "_external_endpoint_port", None): + if protocol not in self._external_endpoint_port_map.keys(): + raise ValueError("Invalid protocol: {}".format(protocol)) + + if not self._external_endpoint_ports.get(protocol): self.reload() - assigned_port = self._get_runtime_properties().get("_PORT") - if assigned_port: - self._external_endpoint_port = assigned_port - if getattr(self, "_external_endpoint_port", None): - if self._external_endpoint_port != port: # noqa + internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol]) + if internal_port: + self._external_endpoint_ports[protocol] = internal_port + if self._external_endpoint_ports.get(protocol): + if self._external_endpoint_ports.get(protocol) != port: # noqa raise ValueError( - "Only one endpoint can be requested at the moment. Port already exposed is: {}".format( - self._external_endpoint_port + "Only one endpoint per protocol can be requested at the moment. Port already exposed is: {}".format( + self._external_endpoint_ports.get(protocol) ) ) return # noinspection PyProtectedMember self._set_runtime_properties( - {"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": port} + { + "_SERVICE": self._external_endpoint_service_map[protocol], + self._external_endpoint_address_map[protocol]: get_private_ip(), + self._external_endpoint_port_map[protocol]: port, + } ) self.set_system_tags((self.get_system_tags() or []) + ["external_service"]) - self._external_endpoint_port = port + self._external_endpoint_ports[protocol] = port if wait: - return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds) + return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds, protocol=protocol) return None - def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0): - # type: (float) -> Optional[Dict] + def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http"): + # type: (float, float, Optional[str]) -> Union[Optional[Dict], List[Optional[Dict]]] """ Wait for an external endpoint to be assigned :param wait_interval_seconds: The poll frequency when waiting for the endpoint :param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint, the method will no longer wait + :param protocol: `http` or `tcp`. Wait for an endpoint to be assigned based on the protocol. + If None, wait for all supported protocols :return: If no endpoint could be found while waiting, this mehtod returns None. - Otherwise, it returns a dictionary containing the following values: + If a protocol has been specified, it returns a dictionary containing the following values: - endpoint - raw endpoint. One might need to authenticate in order to use this endpoint - browser_endpoint - endpoint to be used in browser. Authentication will be handled via the browser - port - the port exposed by the application - protocol - the protocol used by the endpoint + If not protocol is specified, it returns a list of dictionaries containing the values above, + for each protocol requested and waited """ Session.verify_feature_set("advanced") - if not getattr(self, "_external_endpoint_port", None): - LoggerRoot.get_base_logger().warning("No external endpoints have been requested") + if protocol: + return self._wait_for_external_endpoint( + wait_interval_seconds=wait_interval_seconds, + wait_timeout_seconds=wait_timeout_seconds, + protocol=protocol, + warn=True + ) + results = [] + protocols = ["http", "tcp"] + waited_protocols = [] + for protocol_ in protocols: + start_time = time.time() + result = self._wait_for_external_endpoint( + wait_interval_seconds=wait_interval_seconds, + wait_timeout_seconds=wait_timeout_seconds, + protocol=protocol_, + warn=False, + ) + elapsed = time.time() - start_time + if result: + results.append(result) + wait_timeout_seconds -= elapsed + if wait_timeout_seconds > 0 or result: + waited_protocols.append(protocol_) + unwaited_protocols = [p for p in protocols if p not in waited_protocols] + if wait_timeout_seconds <= 0 and unwaited_protocols: + LoggerRoot.get_base_logger().warning( + "Timeout exceeded while waiting for {} endpoint(s)".format(",".join(unwaited_protocols)) + ) + return results + + def _wait_for_external_endpoint( + self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http", warn=True + ): + if not self._external_endpoint_ports.get(protocol): + self.reload() + internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol]) + if internal_port: + self._external_endpoint_ports[protocol] = internal_port + if not self._external_endpoint_ports.get(protocol): + if warn: + LoggerRoot.get_base_logger().warning("No external {} endpoints have been requested".format(protocol)) return None start_time = time.time() while True: self.reload() - # noinspection PyProtectedMember runtime_props = self._get_runtime_properties() - endpoint = runtime_props.get("endpoint") - browser_endpoint = runtime_props.get("browser_endpoint") - if not getattr(self, "_external_endpoint_port", None): - self._external_endpoint_port = runtime_props.get("_PORT") + endpoint, browser_endpoint = None, None + if protocol == "http": + endpoint = runtime_props.get("endpoint") + browser_endpoint = runtime_props.get("browser_endpoint") + elif protocol == "tcp": + health_check = runtime_props.get("upstream_task_port") + if health_check: + endpoint = ( + runtime_props.get(self._external_endpoint_address_map[protocol]) + + ":" + + str(runtime_props.get(self._external_endpoint_port_map[protocol])) + ) if endpoint or browser_endpoint: return { "endpoint": endpoint, "browser_endpoint": browser_endpoint, - "port": self._external_endpoint_port, - "protocol": "http", + "port": self._external_endpoint_ports[protocol], + "protocol": protocol, } if time.time() >= start_time + wait_timeout_seconds: - LoggerRoot.get_base_logger().warning("Timeout exceeded while waiting for endpoint") + if warn: + LoggerRoot.get_base_logger().warning( + "Timeout exceeded while waiting for {} endpoint".format(protocol) + ) return None time.sleep(wait_interval_seconds) @@ -941,23 +1008,36 @@ class Task(_Task): - protocol - the protocol used by the endpoint """ Session.verify_feature_set("advanced") - if not getattr(self, "_external_endpoint_port", None): - self.reload() - self._external_endpoint_port = self._get_runtime_properties().get("_PORT") - if not getattr(self, "_external_endpoint_port", None): - LoggerRoot.get_base_logger().warning("No external endpoints have been requested") - return [] runtime_props = self._get_runtime_properties() - endpoint = runtime_props.get("endpoint") - browser_endpoint = runtime_props.get("browser_endpoint") - return [ - { - "endpoint": endpoint, - "browser_endpoint": browser_endpoint, - "port": self._external_endpoint_port, - "protocol": "http", - } - ] + results = [] + for protocol in ["http", "tcp"]: + internal_port = runtime_props.get(self._external_endpoint_internal_port_map[protocol]) + if internal_port: + self._external_endpoint_ports[protocol] = internal_port + else: + continue + endpoint, browser_endpoint = None, None + if protocol == "http": + endpoint = runtime_props.get("endpoint") + browser_endpoint = runtime_props.get("browser_endpoint") + elif protocol == "tcp": + health_check = runtime_props.get("upstream_task_port") + if health_check: + endpoint = ( + runtime_props.get(self._external_endpoint_address_map[protocol]) + + ":" + + str(runtime_props.get(self._external_endpoint_port_map[protocol])) + ) + if endpoint or browser_endpoint: + results.append( + { + "endpoint": endpoint, + "browser_endpoint": browser_endpoint, + "port": internal_port, + "protocol": protocol, + } + ) + return results @classmethod def create(