Add TCP protocol support to Task.request_external_endpoint()

This commit is contained in:
clearml 2024-11-17 10:51:18 +02:00
parent cd6d579944
commit 036d7b6ef2

View File

@ -185,6 +185,11 @@ class Task(_Task):
_launch_multi_node_section = "launch_multi_node"
_launch_multi_node_instance_tag = "multi_node_instance"
_external_endpoint_port_map = {"http": "_PORT", "tcp": "external_tcp_port"}
_external_endpoint_address_map = {"http": "_ADDRESS", "tcp": "external_address"}
_external_endpoint_service_map = {"http": "EXTERNAL", "tcp": "EXTERNAL_TCP"}
_external_endpoint_internal_port_map = {"http": "_PORT", "tcp": "upstream_task_port"}
class _ConnectedParametersType(object):
argparse = "argument_parser"
dictionary = "dictionary"
@ -218,6 +223,7 @@ class Task(_Task):
self._resource_monitor = None
self._calling_filename = None
self._remote_functions_generated = {}
self._external_endpoint_ports = {}
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
@ -850,7 +856,7 @@ class Task(_Task):
Request an external endpoint for an application
:param port: Port the application is listening to
:param protocol: As of now, only `http` is supported
:param protocol: `http` or `tcp`
:param wait: If True, wait for the endpoint to be assigned
:param wait_interval_seconds: The poll frequency when waiting for the endpoint
:param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint,
@ -865,67 +871,128 @@ class Task(_Task):
- protocol - the protocol used by the endpoint
"""
Session.verify_feature_set("advanced")
if not getattr(self, "_external_endpoint_port", None):
if protocol not in self._external_endpoint_port_map.keys():
raise ValueError("Invalid protocol: {}".format(protocol))
if not self._external_endpoint_ports.get(protocol):
self.reload()
assigned_port = self._get_runtime_properties().get("_PORT")
if assigned_port:
self._external_endpoint_port = assigned_port
if getattr(self, "_external_endpoint_port", None):
if self._external_endpoint_port != port: # noqa
internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol])
if internal_port:
self._external_endpoint_ports[protocol] = internal_port
if self._external_endpoint_ports.get(protocol):
if self._external_endpoint_ports.get(protocol) != port: # noqa
raise ValueError(
"Only one endpoint can be requested at the moment. Port already exposed is: {}".format(
self._external_endpoint_port
"Only one endpoint per protocol can be requested at the moment. Port already exposed is: {}".format(
self._external_endpoint_ports.get(protocol)
)
)
return
# noinspection PyProtectedMember
self._set_runtime_properties(
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": port}
{
"_SERVICE": self._external_endpoint_service_map[protocol],
self._external_endpoint_address_map[protocol]: get_private_ip(),
self._external_endpoint_port_map[protocol]: port,
}
)
self.set_system_tags((self.get_system_tags() or []) + ["external_service"])
self._external_endpoint_port = port
self._external_endpoint_ports[protocol] = port
if wait:
return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds)
return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds, protocol=protocol)
return None
def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0):
# type: (float) -> Optional[Dict]
def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http"):
# type: (float, float, Optional[str]) -> Union[Optional[Dict], List[Optional[Dict]]]
"""
Wait for an external endpoint to be assigned
:param wait_interval_seconds: The poll frequency when waiting for the endpoint
:param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint,
the method will no longer wait
:param protocol: `http` or `tcp`. Wait for an endpoint to be assigned based on the protocol.
If None, wait for all supported protocols
:return: If no endpoint could be found while waiting, this mehtod returns None.
Otherwise, it returns a dictionary containing the following values:
If a protocol has been specified, it returns a dictionary containing the following values:
- endpoint - raw endpoint. One might need to authenticate in order to use this endpoint
- browser_endpoint - endpoint to be used in browser. Authentication will be handled via the browser
- port - the port exposed by the application
- protocol - the protocol used by the endpoint
If not protocol is specified, it returns a list of dictionaries containing the values above,
for each protocol requested and waited
"""
Session.verify_feature_set("advanced")
if not getattr(self, "_external_endpoint_port", None):
LoggerRoot.get_base_logger().warning("No external endpoints have been requested")
if protocol:
return self._wait_for_external_endpoint(
wait_interval_seconds=wait_interval_seconds,
wait_timeout_seconds=wait_timeout_seconds,
protocol=protocol,
warn=True
)
results = []
protocols = ["http", "tcp"]
waited_protocols = []
for protocol_ in protocols:
start_time = time.time()
result = self._wait_for_external_endpoint(
wait_interval_seconds=wait_interval_seconds,
wait_timeout_seconds=wait_timeout_seconds,
protocol=protocol_,
warn=False,
)
elapsed = time.time() - start_time
if result:
results.append(result)
wait_timeout_seconds -= elapsed
if wait_timeout_seconds > 0 or result:
waited_protocols.append(protocol_)
unwaited_protocols = [p for p in protocols if p not in waited_protocols]
if wait_timeout_seconds <= 0 and unwaited_protocols:
LoggerRoot.get_base_logger().warning(
"Timeout exceeded while waiting for {} endpoint(s)".format(",".join(unwaited_protocols))
)
return results
def _wait_for_external_endpoint(
self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http", warn=True
):
if not self._external_endpoint_ports.get(protocol):
self.reload()
internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol])
if internal_port:
self._external_endpoint_ports[protocol] = internal_port
if not self._external_endpoint_ports.get(protocol):
if warn:
LoggerRoot.get_base_logger().warning("No external {} endpoints have been requested".format(protocol))
return None
start_time = time.time()
while True:
self.reload()
# noinspection PyProtectedMember
runtime_props = self._get_runtime_properties()
endpoint = runtime_props.get("endpoint")
browser_endpoint = runtime_props.get("browser_endpoint")
if not getattr(self, "_external_endpoint_port", None):
self._external_endpoint_port = runtime_props.get("_PORT")
endpoint, browser_endpoint = None, None
if protocol == "http":
endpoint = runtime_props.get("endpoint")
browser_endpoint = runtime_props.get("browser_endpoint")
elif protocol == "tcp":
health_check = runtime_props.get("upstream_task_port")
if health_check:
endpoint = (
runtime_props.get(self._external_endpoint_address_map[protocol])
+ ":"
+ str(runtime_props.get(self._external_endpoint_port_map[protocol]))
)
if endpoint or browser_endpoint:
return {
"endpoint": endpoint,
"browser_endpoint": browser_endpoint,
"port": self._external_endpoint_port,
"protocol": "http",
"port": self._external_endpoint_ports[protocol],
"protocol": protocol,
}
if time.time() >= start_time + wait_timeout_seconds:
LoggerRoot.get_base_logger().warning("Timeout exceeded while waiting for endpoint")
if warn:
LoggerRoot.get_base_logger().warning(
"Timeout exceeded while waiting for {} endpoint".format(protocol)
)
return None
time.sleep(wait_interval_seconds)
@ -941,23 +1008,36 @@ class Task(_Task):
- protocol - the protocol used by the endpoint
"""
Session.verify_feature_set("advanced")
if not getattr(self, "_external_endpoint_port", None):
self.reload()
self._external_endpoint_port = self._get_runtime_properties().get("_PORT")
if not getattr(self, "_external_endpoint_port", None):
LoggerRoot.get_base_logger().warning("No external endpoints have been requested")
return []
runtime_props = self._get_runtime_properties()
endpoint = runtime_props.get("endpoint")
browser_endpoint = runtime_props.get("browser_endpoint")
return [
{
"endpoint": endpoint,
"browser_endpoint": browser_endpoint,
"port": self._external_endpoint_port,
"protocol": "http",
}
]
results = []
for protocol in ["http", "tcp"]:
internal_port = runtime_props.get(self._external_endpoint_internal_port_map[protocol])
if internal_port:
self._external_endpoint_ports[protocol] = internal_port
else:
continue
endpoint, browser_endpoint = None, None
if protocol == "http":
endpoint = runtime_props.get("endpoint")
browser_endpoint = runtime_props.get("browser_endpoint")
elif protocol == "tcp":
health_check = runtime_props.get("upstream_task_port")
if health_check:
endpoint = (
runtime_props.get(self._external_endpoint_address_map[protocol])
+ ":"
+ str(runtime_props.get(self._external_endpoint_port_map[protocol]))
)
if endpoint or browser_endpoint:
results.append(
{
"endpoint": endpoint,
"browser_endpoint": browser_endpoint,
"port": internal_port,
"protocol": protocol,
}
)
return results
@classmethod
def create(