diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py index fbe1de42..6f6651ee 100644 --- a/clearml/binding/frameworks/pytorch_bind.py +++ b/clearml/binding/frameworks/pytorch_bind.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, Literal +from typing import Any, Callable import six import threading @@ -111,8 +111,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO): return ret @staticmethod - def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]): - + def _patch_lightning_io_internal(lightning_name): + # type: (str) -> None + """ + :param lightning_name: lightning module name, use "lightning" or "pytorch_lightning" + """ try: pytorch_lightning = importlib.import_module(lightning_name) except ImportError: @@ -122,8 +125,8 @@ class PatchPyTorchModelIO(PatchBaseModelIO): if lightning_name == "lightning": pytorch_lightning = pytorch_lightning.pytorch - def patch_method(cls: type, method_name: str, - patched_method: Callable[..., Any]) -> None: + def patch_method(cls, method_name, patched_method): + # type: (type, str, Callable[..., Any]) -> None """ Patch a method of a class if it exists.