diff --git a/clearml/binding/gradio_bind.py b/clearml/binding/gradio_bind.py index 8bcb69f0..c21e8832 100644 --- a/clearml/binding/gradio_bind.py +++ b/clearml/binding/gradio_bind.py @@ -1,12 +1,9 @@ +import sys from logging import getLogger from .frameworks import _patched_call # noqa +from .import_bind import PostImportHookPatching from ..utilities.networking import get_private_ip -try: - import gradio -except ImportError: - gradio = None - class PatchGradio: _current_task = None @@ -19,17 +16,29 @@ class PatchGradio: @classmethod def update_current_task(cls, task=None): - if gradio is None: - return - cls._current_task = task + if cls.__patched: + return + if "gradio" in sys.modules: + cls.patch_gradio() + else: + PostImportHookPatching.add_on_import("gradio", cls.patch_gradio) + + @classmethod + def patch_gradio(cls): + if cls.__patched: + return + # noinspection PyBroadException + try: + import gradio - if not cls.__patched: - cls.__patched = True gradio.networking.start_server = _patched_call( gradio.networking.start_server, PatchGradio._patched_start_server ) gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init) + except Exception: + pass + cls.__patched = True @staticmethod def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):