Patch gradio only when imported

This commit is contained in:
Alex Burlacu 2023-03-23 13:24:13 +02:00
parent 2604401dd3
commit 6ff629c4e7

View File

@ -1,12 +1,9 @@
import sys
from logging import getLogger
from .frameworks import _patched_call # noqa
from .import_bind import PostImportHookPatching
from ..utilities.networking import get_private_ip
try:
import gradio
except ImportError:
gradio = None
class PatchGradio:
_current_task = None
@ -19,17 +16,29 @@ class PatchGradio:
@classmethod
def update_current_task(cls, task=None):
if gradio is None:
return
cls._current_task = task
if cls.__patched:
return
if "gradio" in sys.modules:
cls.patch_gradio()
else:
PostImportHookPatching.add_on_import("gradio", cls.patch_gradio)
@classmethod
def patch_gradio(cls):
if cls.__patched:
return
# noinspection PyBroadException
try:
import gradio
if not cls.__patched:
cls.__patched = True
gradio.networking.start_server = _patched_call(
gradio.networking.start_server, PatchGradio._patched_start_server
)
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
except Exception:
pass
cls.__patched = True
@staticmethod
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):