mirror of
https://github.com/clearml/clearml
synced 2025-03-14 15:48:21 +00:00
105 lines
4.1 KiB
Python
105 lines
4.1 KiB
Python
import sys
|
|
from logging import getLogger
|
|
from .frameworks import _patched_call # noqa
|
|
from .import_bind import PostImportHookPatching
|
|
from ..utilities.networking import get_private_ip
|
|
from ..config import running_remotely
|
|
|
|
|
|
class PatchGradio:
    """Patch Gradio so apps launched inside a ClearML task are exposed through the task.

    Once patched, ``gradio.blocks.Blocks.launch`` and ``gradio.routes.App.get_blocks``
    are wrapped. When running remotely with a current task, ``launch`` is forced to
    bind to ``_default_gradio_address:_default_gradio_port`` under a task-specific
    root path. The task also gets the address/port recorded as runtime properties and
    the ``external_service`` system tag.
    """

    # Task Gradio apps are associated with; set through update_current_task()
    _current_task = None
    # Set once patching was attempted (even if it failed), so it is never retried
    __patched = False

    # Address/port Gradio is forced to listen on in remote execution
    _default_gradio_address = "0.0.0.0"
    _default_gradio_port = 7860
    # Root path the app is served under; formatted with the task id
    _root_path_format = "/service/{}/"
    # (server_name, server_port, root_path) combinations already warned about,
    # so each distinct user configuration is only logged once
    __server_config_warning = set()

    @classmethod
    def update_current_task(cls, task=None):
        """Bind ``task`` as the current task and make sure Gradio gets patched.

        If ``gradio`` is already imported it is patched immediately; otherwise the
        patch is registered to run when ``gradio`` is imported.

        :param task: The ClearML task to associate with Gradio apps (may be None).
        """
        cls._current_task = task
        if cls.__patched:
            return
        if "gradio" in sys.modules:
            cls.patch_gradio()
        else:
            PostImportHookPatching.add_on_import("gradio", cls.patch_gradio)

    @classmethod
    def patch_gradio(cls):
        """Wrap ``App.get_blocks`` and ``Blocks.launch`` with the ClearML hooks (at most once).

        Patching is best-effort. Any failure (for example an incompatible Gradio
        version) is swallowed, and the class is marked as patched either way.
        """
        if cls.__patched:
            return
        # noinspection PyBroadException
        try:
            import gradio

            gradio.routes.App.get_blocks = _patched_call(gradio.routes.App.get_blocks, PatchGradio._patched_get_blocks)
            gradio.blocks.Blocks.launch = _patched_call(gradio.blocks.Blocks.launch, PatchGradio._patched_launch)
        except Exception:
            pass
        cls.__patched = True

    @staticmethod
    def _patched_get_blocks(original_fn, *args, **kwargs):
        """Call the original ``App.get_blocks``; in remote execution, set its root to the task path.

        :param original_fn: The unpatched ``App.get_blocks``.
        :return: The blocks object returned by ``original_fn``, possibly with
            ``config["root"]`` and ``root`` rewritten.
        """
        blocks = original_fn(*args, **kwargs)
        if not PatchGradio._current_task or not running_remotely():
            return blocks
        blocks.config["root"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
        blocks.root = blocks.config["root"]
        return blocks

    @staticmethod
    def _patched_launch(original_fn, *args, **kwargs):
        """Call the original ``Blocks.launch``; in remote execution, force ClearML's server settings.

        Without a current task the call is passed through untouched. With a task, a
        warning is logged if the user asked for a different server configuration.
        When running remotely, the following is also done before launching:

        - the task's runtime properties and system tags are updated;
        - ``server_name``, ``server_port`` and ``root_path`` are overridden.

        :param original_fn: The unpatched ``Blocks.launch``.
        :return: Whatever ``original_fn`` returns.
        """
        if not PatchGradio._current_task:
            return original_fn(*args, **kwargs)
        PatchGradio.__warn_on_server_config(
            kwargs.get("server_name"),
            kwargs.get("server_port"),
            kwargs.get("root_path")
        )
        if not running_remotely():
            return original_fn(*args, **kwargs)
        # noinspection PyProtectedMember
        PatchGradio._current_task._set_runtime_properties(
            {"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
        )
        PatchGradio._current_task.set_system_tags(["external_service"])
        kwargs["server_name"] = PatchGradio._default_gradio_address
        kwargs["server_port"] = PatchGradio._default_gradio_port
        kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
        # noinspection PyBroadException
        try:
            return original_fn(*args, **kwargs)
        except Exception:
            # Presumably for Gradio versions whose launch() does not accept
            # root_path: retry once without it. TODO confirm the affected versions.
            del kwargs["root_path"]
            return original_fn(*args, **kwargs)

    @classmethod
    def __warn_on_server_config(cls, server_name, server_port, root_path):
        """Warn (once per distinct configuration) about server settings ClearML overrides remotely.

        :param server_name: ``server_name`` passed to ``launch`` (None if not given).
        :param server_port: ``server_port`` passed to ``launch`` (None if not given).
        :param root_path: ``root_path`` passed to ``launch`` (None if not given).
        """
        # A setting only matters if it was given AND differs from what ClearML forces
        name_overridden = server_name is not None and server_name != cls._default_gradio_address
        port_overridden = server_port is not None and server_port != cls._default_gradio_port
        if not name_overridden and not port_overridden and root_path is None:
            return
        config_key = (server_name, server_port, root_path)
        if config_key in cls.__server_config_warning:
            return
        cls.__server_config_warning.add(config_key)
        if name_overridden or port_overridden:
            if name_overridden and port_overridden:
                server_config = "{}:{}".format(server_name, server_port)
                what_to_ignore = "name and port"
            elif name_overridden:
                server_config = str(server_name)
                what_to_ignore = "name"
            else:
                server_config = str(server_port)
                what_to_ignore = "port"
            getLogger().warning(
                "ClearML only supports '{}:{}' as the Gradio server. Ignoring {} '{}' in remote execution".format(
                    cls._default_gradio_address, cls._default_gradio_port, what_to_ignore, server_config
                )
            )
        if root_path is not None:
            getLogger().warning(
                "ClearML will override root_path '{}' to '{}' in remote execution".format(
                    root_path, cls._root_path_format.format(cls._current_task.id)
                )
            )
|