Fix Python <3.8 compatibility

This commit is contained in:
allegroai 2024-05-12 08:57:56 +03:00
parent 66a7f5616c
commit 39d150bf56

View File

@ -1,5 +1,5 @@
import sys
from typing import Any, Callable, Literal
from typing import Any, Callable
import six
import threading
@ -111,8 +111,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return ret
@staticmethod
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
def _patch_lightning_io_internal(lightning_name):
# type: (str) -> None
"""
:param lightning_name: lightning module name, use "lightning" or "pytorch_lightning"
"""
try:
pytorch_lightning = importlib.import_module(lightning_name)
except ImportError:
@ -122,8 +125,8 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if lightning_name == "lightning":
pytorch_lightning = pytorch_lightning.pytorch
def patch_method(cls: type, method_name: str,
patched_method: Callable[..., Any]) -> None:
def patch_method(cls, method_name, patched_method):
# type: (type, str, Callable[..., Any]) -> None
"""
Patch a method of a class if it exists.