Fix Python <3.8 compatibility

This commit is contained in:
allegroai 2024-05-12 08:57:56 +03:00
parent 66a7f5616c
commit 39d150bf56

View File

@ -1,5 +1,5 @@
import sys import sys
from typing import Any, Callable, Literal from typing import Any, Callable
import six import six
import threading import threading
@ -111,8 +111,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return ret return ret
@staticmethod @staticmethod
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]): def _patch_lightning_io_internal(lightning_name):
# type: (str) -> None
"""
:param lightning_name: lightning module name, use "lightning" or "pytorch_lightning"
"""
try: try:
pytorch_lightning = importlib.import_module(lightning_name) pytorch_lightning = importlib.import_module(lightning_name)
except ImportError: except ImportError:
@ -122,8 +125,8 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if lightning_name == "lightning": if lightning_name == "lightning":
pytorch_lightning = pytorch_lightning.pytorch pytorch_lightning = pytorch_lightning.pytorch
def patch_method(cls: type, method_name: str, def patch_method(cls, method_name, patched_method):
patched_method: Callable[..., Any]) -> None: # type: (type, str, Callable[..., Any]) -> None
""" """
Patch a method of a class if it exists. Patch a method of a class if it exists.