mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Patch gradio only when imported
This commit is contained in:
parent
2604401dd3
commit
6ff629c4e7
@ -1,12 +1,9 @@
|
|||||||
|
import sys
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from .frameworks import _patched_call # noqa
|
from .frameworks import _patched_call # noqa
|
||||||
|
from .import_bind import PostImportHookPatching
|
||||||
from ..utilities.networking import get_private_ip
|
from ..utilities.networking import get_private_ip
|
||||||
|
|
||||||
try:
|
|
||||||
import gradio
|
|
||||||
except ImportError:
|
|
||||||
gradio = None
|
|
||||||
|
|
||||||
|
|
||||||
class PatchGradio:
|
class PatchGradio:
|
||||||
_current_task = None
|
_current_task = None
|
||||||
@ -19,17 +16,29 @@ class PatchGradio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, task=None):
|
def update_current_task(cls, task=None):
|
||||||
if gradio is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
cls._current_task = task
|
cls._current_task = task
|
||||||
|
if cls.__patched:
|
||||||
|
return
|
||||||
|
if "gradio" in sys.modules:
|
||||||
|
cls.patch_gradio()
|
||||||
|
else:
|
||||||
|
PostImportHookPatching.add_on_import("gradio", cls.patch_gradio)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patch_gradio(cls):
|
||||||
|
if cls.__patched:
|
||||||
|
return
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
import gradio
|
||||||
|
|
||||||
if not cls.__patched:
|
|
||||||
cls.__patched = True
|
|
||||||
gradio.networking.start_server = _patched_call(
|
gradio.networking.start_server = _patched_call(
|
||||||
gradio.networking.start_server, PatchGradio._patched_start_server
|
gradio.networking.start_server, PatchGradio._patched_start_server
|
||||||
)
|
)
|
||||||
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
|
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cls.__patched = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
|
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user