mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Add Gradio binding support
This commit is contained in:
parent
22715cda19
commit
72b341ee51
78
clearml/binding/gradio_bind.py
Normal file
78
clearml/binding/gradio_bind.py
Normal file
@ -0,0 +1,78 @@
|
||||
from logging import getLogger
|
||||
from .frameworks import _patched_call # noqa
|
||||
from ..utilities.networking import get_private_ip
|
||||
|
||||
try:
|
||||
import gradio
|
||||
except ImportError:
|
||||
gradio = None
|
||||
|
||||
|
||||
class PatchGradio:
|
||||
_current_task = None
|
||||
__patched = False
|
||||
|
||||
_default_gradio_address = "0.0.0.0"
|
||||
_default_gradio_port = 7860
|
||||
_root_path_format = "/service/{}"
|
||||
__server_config_warning = set()
|
||||
|
||||
@classmethod
|
||||
def update_current_task(cls, task=None):
|
||||
if gradio is None:
|
||||
return
|
||||
|
||||
cls._current_task = task
|
||||
|
||||
if not cls.__patched:
|
||||
cls.__patched = True
|
||||
gradio.networking.start_server = _patched_call(
|
||||
gradio.networking.start_server, PatchGradio._patched_start_server
|
||||
)
|
||||
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
|
||||
|
||||
@staticmethod
|
||||
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
|
||||
if not PatchGradio._current_task:
|
||||
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||
PatchGradio._current_task._set_runtime_properties(
|
||||
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
|
||||
)
|
||||
PatchGradio._current_task.set_system_tags(["external_service"])
|
||||
PatchGradio.__warn_on_server_config(server_name, server_port)
|
||||
server_name = PatchGradio._default_gradio_address
|
||||
server_port = PatchGradio._default_gradio_port
|
||||
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _patched_init(original_fn, *args, **kwargs):
|
||||
if not PatchGradio._current_task:
|
||||
return original_fn(*args, **kwargs)
|
||||
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
|
||||
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||
kwargs["root_path_in_servers"] = False
|
||||
kwargs["server_name"] = PatchGradio._default_gradio_address
|
||||
kwargs["server_port"] = PatchGradio._default_gradio_port
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def __warn_on_server_config(cls, server_name, server_port):
|
||||
if server_name is None and server_port is None:
|
||||
return
|
||||
if server_name is not None and server_port is not None:
|
||||
server_config = "{}:{}".format(server_name, server_port)
|
||||
what_to_ignore = "name and port"
|
||||
elif server_name is not None:
|
||||
server_config = str(server_name)
|
||||
what_to_ignore = "name"
|
||||
else:
|
||||
server_config = str(server_port)
|
||||
what_to_ignore = "port"
|
||||
if server_config in cls.__server_config_warning:
|
||||
return
|
||||
cls.__server_config_warning.add(server_config)
|
||||
getLogger().warning(
|
||||
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
|
||||
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
|
||||
)
|
||||
)
|
@ -74,6 +74,7 @@ from .binding.hydra_bind import PatchHydra
|
||||
from .binding.click_bind import PatchClick
|
||||
from .binding.fire_bind import PatchFire
|
||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||
from .binding.gradio_bind import PatchGradio
|
||||
from .binding.frameworks import WeightsFileHandler
|
||||
from .config import (
|
||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||
@ -402,7 +403,7 @@ class Task(_Task):
|
||||
'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True,
|
||||
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
|
||||
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
||||
'joblib': True, 'megengine': True, 'catboost': True
|
||||
'joblib': True, 'megengine': True, 'catboost': True, 'gradio': True
|
||||
}
|
||||
|
||||
.. code-block:: py
|
||||
@ -689,6 +690,8 @@ class Task(_Task):
|
||||
PatchFastai.update_current_task(task)
|
||||
if should_connect("lightgbm"):
|
||||
PatchLIGHTgbmModelIO.update_current_task(task)
|
||||
if should_connect("gradio"):
|
||||
PatchGradio.update_current_task(task)
|
||||
|
||||
cls.__add_model_wildcards(auto_connect_frameworks)
|
||||
|
||||
|
88
clearml/utilities/networking.py
Normal file
88
clearml/utilities/networking.py
Normal file
@ -0,0 +1,88 @@
|
||||
import requests
|
||||
import socket
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_private_ip():
|
||||
# type: () -> str
|
||||
"""
|
||||
Get the private IP of this machine
|
||||
|
||||
:return: A string representing the IP of this machine
|
||||
"""
|
||||
approaches = (
|
||||
_get_private_ip_from_socket,
|
||||
_get_private_ip_from_subprocess,
|
||||
)
|
||||
|
||||
for approach in approaches:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return approach()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise Exception("error getting private IP")
|
||||
|
||||
|
||||
def get_public_ip():
|
||||
# type: () -> Optional[str]
|
||||
"""
|
||||
Get the public IP of this machine. External services such as `https://api.ipify.org` or `https://ident.me`
|
||||
are used to get the IP
|
||||
|
||||
:return: A string representing the IP of this machine or `None` if getting the IP failed
|
||||
"""
|
||||
for external_service in ["https://api.ipify.org", "https://ident.me"]:
|
||||
ip = get_public_ip_from_external_service(external_service)
|
||||
if ip:
|
||||
return ip
|
||||
return None
|
||||
|
||||
|
||||
def get_public_ip_from_external_service(external_service, timeout=5):
|
||||
# type: (str, Optional[int]) -> Optional[str]
|
||||
"""
|
||||
Get the public IP of this machine from an external service.
|
||||
Fetching the IP is done via a GET request. The whole content of the request
|
||||
should be the IP address
|
||||
|
||||
:param external_service: The address of the extrenal service
|
||||
:param timeout: The GET request timeout
|
||||
|
||||
:return: A string representing the IP of this machine or `None` if getting the IP failed
|
||||
"""
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
response = requests.get(external_service, timeout=timeout)
|
||||
if not response.ok:
|
||||
return None
|
||||
ip = response.content.decode("utf8")
|
||||
# check that we actually received an IP address
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET, ip)
|
||||
return ip
|
||||
except Exception:
|
||||
socket.inet_pton(socket.AF_INET6, ip)
|
||||
return ip
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_private_ip_from_socket():
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.settimeout(0)
|
||||
try:
|
||||
s.connect(("8.8.8.8", 1))
|
||||
ip = s.getsockname()[0]
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
|
||||
def _get_private_ip_from_subprocess():
|
||||
return subprocess.check_output("hostname -I", shell=True).split()[0].decode("utf-8")
|
Loading…
Reference in New Issue
Block a user