mirror of
https://github.com/clearml/clearml
synced 2025-01-31 00:56:57 +00:00
Add TCP protocol support to Task.request_external_endpoint()
This commit is contained in:
parent
cd6d579944
commit
036d7b6ef2
162
clearml/task.py
162
clearml/task.py
@ -185,6 +185,11 @@ class Task(_Task):
|
||||
_launch_multi_node_section = "launch_multi_node"
|
||||
_launch_multi_node_instance_tag = "multi_node_instance"
|
||||
|
||||
_external_endpoint_port_map = {"http": "_PORT", "tcp": "external_tcp_port"}
|
||||
_external_endpoint_address_map = {"http": "_ADDRESS", "tcp": "external_address"}
|
||||
_external_endpoint_service_map = {"http": "EXTERNAL", "tcp": "EXTERNAL_TCP"}
|
||||
_external_endpoint_internal_port_map = {"http": "_PORT", "tcp": "upstream_task_port"}
|
||||
|
||||
class _ConnectedParametersType(object):
|
||||
argparse = "argument_parser"
|
||||
dictionary = "dictionary"
|
||||
@ -218,6 +223,7 @@ class Task(_Task):
|
||||
self._resource_monitor = None
|
||||
self._calling_filename = None
|
||||
self._remote_functions_generated = {}
|
||||
self._external_endpoint_ports = {}
|
||||
# register atexit, so that we mark the task as stopped
|
||||
self._at_exit_called = False
|
||||
|
||||
@ -850,7 +856,7 @@ class Task(_Task):
|
||||
Request an external endpoint for an application
|
||||
|
||||
:param port: Port the application is listening to
|
||||
:param protocol: As of now, only `http` is supported
|
||||
:param protocol: `http` or `tcp`
|
||||
:param wait: If True, wait for the endpoint to be assigned
|
||||
:param wait_interval_seconds: The poll frequency when waiting for the endpoint
|
||||
:param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint,
|
||||
@ -865,67 +871,128 @@ class Task(_Task):
|
||||
- protocol - the protocol used by the endpoint
|
||||
"""
|
||||
Session.verify_feature_set("advanced")
|
||||
if not getattr(self, "_external_endpoint_port", None):
|
||||
if protocol not in self._external_endpoint_port_map.keys():
|
||||
raise ValueError("Invalid protocol: {}".format(protocol))
|
||||
|
||||
if not self._external_endpoint_ports.get(protocol):
|
||||
self.reload()
|
||||
assigned_port = self._get_runtime_properties().get("_PORT")
|
||||
if assigned_port:
|
||||
self._external_endpoint_port = assigned_port
|
||||
if getattr(self, "_external_endpoint_port", None):
|
||||
if self._external_endpoint_port != port: # noqa
|
||||
internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol])
|
||||
if internal_port:
|
||||
self._external_endpoint_ports[protocol] = internal_port
|
||||
if self._external_endpoint_ports.get(protocol):
|
||||
if self._external_endpoint_ports.get(protocol) != port: # noqa
|
||||
raise ValueError(
|
||||
"Only one endpoint can be requested at the moment. Port already exposed is: {}".format(
|
||||
self._external_endpoint_port
|
||||
"Only one endpoint per protocol can be requested at the moment. Port already exposed is: {}".format(
|
||||
self._external_endpoint_ports.get(protocol)
|
||||
)
|
||||
)
|
||||
return
|
||||
# noinspection PyProtectedMember
|
||||
self._set_runtime_properties(
|
||||
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": port}
|
||||
{
|
||||
"_SERVICE": self._external_endpoint_service_map[protocol],
|
||||
self._external_endpoint_address_map[protocol]: get_private_ip(),
|
||||
self._external_endpoint_port_map[protocol]: port,
|
||||
}
|
||||
)
|
||||
self.set_system_tags((self.get_system_tags() or []) + ["external_service"])
|
||||
self._external_endpoint_port = port
|
||||
self._external_endpoint_ports[protocol] = port
|
||||
if wait:
|
||||
return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds)
|
||||
return self.wait_for_external_endpoint(wait_interval_seconds=wait_interval_seconds, protocol=protocol)
|
||||
return None
|
||||
|
||||
def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0):
|
||||
# type: (float) -> Optional[Dict]
|
||||
def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http"):
|
||||
# type: (float, float, Optional[str]) -> Union[Optional[Dict], List[Optional[Dict]]]
|
||||
"""
|
||||
Wait for an external endpoint to be assigned
|
||||
|
||||
:param wait_interval_seconds: The poll frequency when waiting for the endpoint
|
||||
:param wait_timeout_seconds: If this timeout is exceeded while waiting for the endpoint,
|
||||
the method will no longer wait
|
||||
:param protocol: `http` or `tcp`. Wait for an endpoint to be assigned based on the protocol.
|
||||
If None, wait for all supported protocols
|
||||
|
||||
:return: If no endpoint could be found while waiting, this mehtod returns None.
|
||||
Otherwise, it returns a dictionary containing the following values:
|
||||
If a protocol has been specified, it returns a dictionary containing the following values:
|
||||
- endpoint - raw endpoint. One might need to authenticate in order to use this endpoint
|
||||
- browser_endpoint - endpoint to be used in browser. Authentication will be handled via the browser
|
||||
- port - the port exposed by the application
|
||||
- protocol - the protocol used by the endpoint
|
||||
If not protocol is specified, it returns a list of dictionaries containing the values above,
|
||||
for each protocol requested and waited
|
||||
"""
|
||||
Session.verify_feature_set("advanced")
|
||||
if not getattr(self, "_external_endpoint_port", None):
|
||||
LoggerRoot.get_base_logger().warning("No external endpoints have been requested")
|
||||
if protocol:
|
||||
return self._wait_for_external_endpoint(
|
||||
wait_interval_seconds=wait_interval_seconds,
|
||||
wait_timeout_seconds=wait_timeout_seconds,
|
||||
protocol=protocol,
|
||||
warn=True
|
||||
)
|
||||
results = []
|
||||
protocols = ["http", "tcp"]
|
||||
waited_protocols = []
|
||||
for protocol_ in protocols:
|
||||
start_time = time.time()
|
||||
result = self._wait_for_external_endpoint(
|
||||
wait_interval_seconds=wait_interval_seconds,
|
||||
wait_timeout_seconds=wait_timeout_seconds,
|
||||
protocol=protocol_,
|
||||
warn=False,
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
if result:
|
||||
results.append(result)
|
||||
wait_timeout_seconds -= elapsed
|
||||
if wait_timeout_seconds > 0 or result:
|
||||
waited_protocols.append(protocol_)
|
||||
unwaited_protocols = [p for p in protocols if p not in waited_protocols]
|
||||
if wait_timeout_seconds <= 0 and unwaited_protocols:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Timeout exceeded while waiting for {} endpoint(s)".format(",".join(unwaited_protocols))
|
||||
)
|
||||
return results
|
||||
|
||||
def _wait_for_external_endpoint(
|
||||
self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http", warn=True
|
||||
):
|
||||
if not self._external_endpoint_ports.get(protocol):
|
||||
self.reload()
|
||||
internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol])
|
||||
if internal_port:
|
||||
self._external_endpoint_ports[protocol] = internal_port
|
||||
if not self._external_endpoint_ports.get(protocol):
|
||||
if warn:
|
||||
LoggerRoot.get_base_logger().warning("No external {} endpoints have been requested".format(protocol))
|
||||
return None
|
||||
start_time = time.time()
|
||||
while True:
|
||||
self.reload()
|
||||
# noinspection PyProtectedMember
|
||||
runtime_props = self._get_runtime_properties()
|
||||
endpoint = runtime_props.get("endpoint")
|
||||
browser_endpoint = runtime_props.get("browser_endpoint")
|
||||
if not getattr(self, "_external_endpoint_port", None):
|
||||
self._external_endpoint_port = runtime_props.get("_PORT")
|
||||
endpoint, browser_endpoint = None, None
|
||||
if protocol == "http":
|
||||
endpoint = runtime_props.get("endpoint")
|
||||
browser_endpoint = runtime_props.get("browser_endpoint")
|
||||
elif protocol == "tcp":
|
||||
health_check = runtime_props.get("upstream_task_port")
|
||||
if health_check:
|
||||
endpoint = (
|
||||
runtime_props.get(self._external_endpoint_address_map[protocol])
|
||||
+ ":"
|
||||
+ str(runtime_props.get(self._external_endpoint_port_map[protocol]))
|
||||
)
|
||||
if endpoint or browser_endpoint:
|
||||
return {
|
||||
"endpoint": endpoint,
|
||||
"browser_endpoint": browser_endpoint,
|
||||
"port": self._external_endpoint_port,
|
||||
"protocol": "http",
|
||||
"port": self._external_endpoint_ports[protocol],
|
||||
"protocol": protocol,
|
||||
}
|
||||
if time.time() >= start_time + wait_timeout_seconds:
|
||||
LoggerRoot.get_base_logger().warning("Timeout exceeded while waiting for endpoint")
|
||||
if warn:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Timeout exceeded while waiting for {} endpoint".format(protocol)
|
||||
)
|
||||
return None
|
||||
time.sleep(wait_interval_seconds)
|
||||
|
||||
@ -941,23 +1008,36 @@ class Task(_Task):
|
||||
- protocol - the protocol used by the endpoint
|
||||
"""
|
||||
Session.verify_feature_set("advanced")
|
||||
if not getattr(self, "_external_endpoint_port", None):
|
||||
self.reload()
|
||||
self._external_endpoint_port = self._get_runtime_properties().get("_PORT")
|
||||
if not getattr(self, "_external_endpoint_port", None):
|
||||
LoggerRoot.get_base_logger().warning("No external endpoints have been requested")
|
||||
return []
|
||||
runtime_props = self._get_runtime_properties()
|
||||
endpoint = runtime_props.get("endpoint")
|
||||
browser_endpoint = runtime_props.get("browser_endpoint")
|
||||
return [
|
||||
{
|
||||
"endpoint": endpoint,
|
||||
"browser_endpoint": browser_endpoint,
|
||||
"port": self._external_endpoint_port,
|
||||
"protocol": "http",
|
||||
}
|
||||
]
|
||||
results = []
|
||||
for protocol in ["http", "tcp"]:
|
||||
internal_port = runtime_props.get(self._external_endpoint_internal_port_map[protocol])
|
||||
if internal_port:
|
||||
self._external_endpoint_ports[protocol] = internal_port
|
||||
else:
|
||||
continue
|
||||
endpoint, browser_endpoint = None, None
|
||||
if protocol == "http":
|
||||
endpoint = runtime_props.get("endpoint")
|
||||
browser_endpoint = runtime_props.get("browser_endpoint")
|
||||
elif protocol == "tcp":
|
||||
health_check = runtime_props.get("upstream_task_port")
|
||||
if health_check:
|
||||
endpoint = (
|
||||
runtime_props.get(self._external_endpoint_address_map[protocol])
|
||||
+ ":"
|
||||
+ str(runtime_props.get(self._external_endpoint_port_map[protocol]))
|
||||
)
|
||||
if endpoint or browser_endpoint:
|
||||
results.append(
|
||||
{
|
||||
"endpoint": endpoint,
|
||||
"browser_endpoint": browser_endpoint,
|
||||
"port": internal_port,
|
||||
"protocol": protocol,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
Loading…
Reference in New Issue
Block a user