Add Gradio binding support

This commit is contained in:
allegroai 2023-02-28 17:12:11 +02:00
parent 22715cda19
commit 72b341ee51
3 changed files with 170 additions and 1 deletions

View File

@ -0,0 +1,78 @@
from logging import getLogger
from .frameworks import _patched_call # noqa
from ..utilities.networking import get_private_ip
try:
import gradio
except ImportError:
gradio = None
class PatchGradio:
_current_task = None
__patched = False
_default_gradio_address = "0.0.0.0"
_default_gradio_port = 7860
_root_path_format = "/service/{}"
__server_config_warning = set()
@classmethod
def update_current_task(cls, task=None):
if gradio is None:
return
cls._current_task = task
if not cls.__patched:
cls.__patched = True
gradio.networking.start_server = _patched_call(
gradio.networking.start_server, PatchGradio._patched_start_server
)
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
@staticmethod
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
if not PatchGradio._current_task:
return original_fn(self, server_name, server_port, *args, **kwargs)
PatchGradio._current_task._set_runtime_properties(
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
)
PatchGradio._current_task.set_system_tags(["external_service"])
PatchGradio.__warn_on_server_config(server_name, server_port)
server_name = PatchGradio._default_gradio_address
server_port = PatchGradio._default_gradio_port
return original_fn(self, server_name, server_port, *args, **kwargs)
@staticmethod
def _patched_init(original_fn, *args, **kwargs):
if not PatchGradio._current_task:
return original_fn(*args, **kwargs)
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
kwargs["root_path_in_servers"] = False
kwargs["server_name"] = PatchGradio._default_gradio_address
kwargs["server_port"] = PatchGradio._default_gradio_port
return original_fn(*args, **kwargs)
@classmethod
def __warn_on_server_config(cls, server_name, server_port):
if server_name is None and server_port is None:
return
if server_name is not None and server_port is not None:
server_config = "{}:{}".format(server_name, server_port)
what_to_ignore = "name and port"
elif server_name is not None:
server_config = str(server_name)
what_to_ignore = "name"
else:
server_config = str(server_port)
what_to_ignore = "port"
if server_config in cls.__server_config_warning:
return
cls.__server_config_warning.add(server_config)
getLogger().warning(
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
)
)

View File

@ -74,6 +74,7 @@ from .binding.hydra_bind import PatchHydra
from .binding.click_bind import PatchClick
from .binding.fire_bind import PatchFire
from .binding.jsonargs_bind import PatchJsonArgParse
from .binding.gradio_bind import PatchGradio
from .binding.frameworks import WeightsFileHandler
from .config import (
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
@ -402,7 +403,7 @@ class Task(_Task):
'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True,
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
'joblib': True, 'megengine': True, 'catboost': True
'joblib': True, 'megengine': True, 'catboost': True, 'gradio': True
}
.. code-block:: py
@ -689,6 +690,8 @@ class Task(_Task):
PatchFastai.update_current_task(task)
if should_connect("lightgbm"):
PatchLIGHTgbmModelIO.update_current_task(task)
if should_connect("gradio"):
PatchGradio.update_current_task(task)
cls.__add_model_wildcards(auto_connect_frameworks)

View File

@ -0,0 +1,88 @@
import requests
import socket
import subprocess
from typing import Optional
def get_private_ip():
# type: () -> str
"""
Get the private IP of this machine
:return: A string representing the IP of this machine
"""
approaches = (
_get_private_ip_from_socket,
_get_private_ip_from_subprocess,
)
for approach in approaches:
# noinspection PyBroadException
try:
return approach()
except Exception:
continue
raise Exception("error getting private IP")
def get_public_ip():
# type: () -> Optional[str]
"""
Get the public IP of this machine. External services such as `https://api.ipify.org` or `https://ident.me`
are used to get the IP
:return: A string representing the IP of this machine or `None` if getting the IP failed
"""
for external_service in ["https://api.ipify.org", "https://ident.me"]:
ip = get_public_ip_from_external_service(external_service)
if ip:
return ip
return None
def get_public_ip_from_external_service(external_service, timeout=5):
# type: (str, Optional[int]) -> Optional[str]
"""
Get the public IP of this machine from an external service.
Fetching the IP is done via a GET request. The whole content of the request
should be the IP address
:param external_service: The address of the extrenal service
:param timeout: The GET request timeout
:return: A string representing the IP of this machine or `None` if getting the IP failed
"""
# noinspection PyBroadException
try:
response = requests.get(external_service, timeout=timeout)
if not response.ok:
return None
ip = response.content.decode("utf8")
# check that we actually received an IP address
# noinspection PyBroadException
try:
socket.inet_pton(socket.AF_INET, ip)
return ip
except Exception:
socket.inet_pton(socket.AF_INET6, ip)
return ip
except Exception:
return None
def _get_private_ip_from_socket():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
try:
s.connect(("8.8.8.8", 1))
ip = s.getsockname()[0]
except Exception as e:
raise e
finally:
s.close()
return ip
def _get_private_ip_from_subprocess():
return subprocess.check_output("hostname -I", shell=True).split()[0].decode("utf-8")