mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add Gradio binding support
This commit is contained in:
parent
22715cda19
commit
72b341ee51
78
clearml/binding/gradio_bind.py
Normal file
78
clearml/binding/gradio_bind.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from logging import getLogger
|
||||||
|
from .frameworks import _patched_call # noqa
|
||||||
|
from ..utilities.networking import get_private_ip
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gradio
|
||||||
|
except ImportError:
|
||||||
|
gradio = None
|
||||||
|
|
||||||
|
|
||||||
|
class PatchGradio:
|
||||||
|
_current_task = None
|
||||||
|
__patched = False
|
||||||
|
|
||||||
|
_default_gradio_address = "0.0.0.0"
|
||||||
|
_default_gradio_port = 7860
|
||||||
|
_root_path_format = "/service/{}"
|
||||||
|
__server_config_warning = set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_current_task(cls, task=None):
|
||||||
|
if gradio is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cls._current_task = task
|
||||||
|
|
||||||
|
if not cls.__patched:
|
||||||
|
cls.__patched = True
|
||||||
|
gradio.networking.start_server = _patched_call(
|
||||||
|
gradio.networking.start_server, PatchGradio._patched_start_server
|
||||||
|
)
|
||||||
|
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
|
||||||
|
if not PatchGradio._current_task:
|
||||||
|
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||||
|
PatchGradio._current_task._set_runtime_properties(
|
||||||
|
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
|
||||||
|
)
|
||||||
|
PatchGradio._current_task.set_system_tags(["external_service"])
|
||||||
|
PatchGradio.__warn_on_server_config(server_name, server_port)
|
||||||
|
server_name = PatchGradio._default_gradio_address
|
||||||
|
server_port = PatchGradio._default_gradio_port
|
||||||
|
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patched_init(original_fn, *args, **kwargs):
|
||||||
|
if not PatchGradio._current_task:
|
||||||
|
return original_fn(*args, **kwargs)
|
||||||
|
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
|
||||||
|
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||||
|
kwargs["root_path_in_servers"] = False
|
||||||
|
kwargs["server_name"] = PatchGradio._default_gradio_address
|
||||||
|
kwargs["server_port"] = PatchGradio._default_gradio_port
|
||||||
|
return original_fn(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __warn_on_server_config(cls, server_name, server_port):
|
||||||
|
if server_name is None and server_port is None:
|
||||||
|
return
|
||||||
|
if server_name is not None and server_port is not None:
|
||||||
|
server_config = "{}:{}".format(server_name, server_port)
|
||||||
|
what_to_ignore = "name and port"
|
||||||
|
elif server_name is not None:
|
||||||
|
server_config = str(server_name)
|
||||||
|
what_to_ignore = "name"
|
||||||
|
else:
|
||||||
|
server_config = str(server_port)
|
||||||
|
what_to_ignore = "port"
|
||||||
|
if server_config in cls.__server_config_warning:
|
||||||
|
return
|
||||||
|
cls.__server_config_warning.add(server_config)
|
||||||
|
getLogger().warning(
|
||||||
|
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
|
||||||
|
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
|
||||||
|
)
|
||||||
|
)
|
@ -74,6 +74,7 @@ from .binding.hydra_bind import PatchHydra
|
|||||||
from .binding.click_bind import PatchClick
|
from .binding.click_bind import PatchClick
|
||||||
from .binding.fire_bind import PatchFire
|
from .binding.fire_bind import PatchFire
|
||||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||||
|
from .binding.gradio_bind import PatchGradio
|
||||||
from .binding.frameworks import WeightsFileHandler
|
from .binding.frameworks import WeightsFileHandler
|
||||||
from .config import (
|
from .config import (
|
||||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||||
@ -402,7 +403,7 @@ class Task(_Task):
|
|||||||
'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True,
|
'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True,
|
||||||
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
|
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
|
||||||
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
||||||
'joblib': True, 'megengine': True, 'catboost': True
|
'joblib': True, 'megengine': True, 'catboost': True, 'gradio': True
|
||||||
}
|
}
|
||||||
|
|
||||||
.. code-block:: py
|
.. code-block:: py
|
||||||
@ -689,6 +690,8 @@ class Task(_Task):
|
|||||||
PatchFastai.update_current_task(task)
|
PatchFastai.update_current_task(task)
|
||||||
if should_connect("lightgbm"):
|
if should_connect("lightgbm"):
|
||||||
PatchLIGHTgbmModelIO.update_current_task(task)
|
PatchLIGHTgbmModelIO.update_current_task(task)
|
||||||
|
if should_connect("gradio"):
|
||||||
|
PatchGradio.update_current_task(task)
|
||||||
|
|
||||||
cls.__add_model_wildcards(auto_connect_frameworks)
|
cls.__add_model_wildcards(auto_connect_frameworks)
|
||||||
|
|
||||||
|
88
clearml/utilities/networking.py
Normal file
88
clearml/utilities/networking.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import requests
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def get_private_ip():
|
||||||
|
# type: () -> str
|
||||||
|
"""
|
||||||
|
Get the private IP of this machine
|
||||||
|
|
||||||
|
:return: A string representing the IP of this machine
|
||||||
|
"""
|
||||||
|
approaches = (
|
||||||
|
_get_private_ip_from_socket,
|
||||||
|
_get_private_ip_from_subprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
for approach in approaches:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
return approach()
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise Exception("error getting private IP")
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_ip():
|
||||||
|
# type: () -> Optional[str]
|
||||||
|
"""
|
||||||
|
Get the public IP of this machine. External services such as `https://api.ipify.org` or `https://ident.me`
|
||||||
|
are used to get the IP
|
||||||
|
|
||||||
|
:return: A string representing the IP of this machine or `None` if getting the IP failed
|
||||||
|
"""
|
||||||
|
for external_service in ["https://api.ipify.org", "https://ident.me"]:
|
||||||
|
ip = get_public_ip_from_external_service(external_service)
|
||||||
|
if ip:
|
||||||
|
return ip
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_ip_from_external_service(external_service, timeout=5):
|
||||||
|
# type: (str, Optional[int]) -> Optional[str]
|
||||||
|
"""
|
||||||
|
Get the public IP of this machine from an external service.
|
||||||
|
Fetching the IP is done via a GET request. The whole content of the request
|
||||||
|
should be the IP address
|
||||||
|
|
||||||
|
:param external_service: The address of the extrenal service
|
||||||
|
:param timeout: The GET request timeout
|
||||||
|
|
||||||
|
:return: A string representing the IP of this machine or `None` if getting the IP failed
|
||||||
|
"""
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
response = requests.get(external_service, timeout=timeout)
|
||||||
|
if not response.ok:
|
||||||
|
return None
|
||||||
|
ip = response.content.decode("utf8")
|
||||||
|
# check that we actually received an IP address
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
socket.inet_pton(socket.AF_INET, ip)
|
||||||
|
return ip
|
||||||
|
except Exception:
|
||||||
|
socket.inet_pton(socket.AF_INET6, ip)
|
||||||
|
return ip
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_private_ip_from_socket():
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
s.settimeout(0)
|
||||||
|
try:
|
||||||
|
s.connect(("8.8.8.8", 1))
|
||||||
|
ip = s.getsockname()[0]
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
s.close()
|
||||||
|
return ip
|
||||||
|
|
||||||
|
|
||||||
|
def _get_private_ip_from_subprocess():
|
||||||
|
return subprocess.check_output("hostname -I", shell=True).split()[0].decode("utf-8")
|
Loading…
Reference in New Issue
Block a user