import sys

import six
from pathlib2 import Path

from ...binding.frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework


class PatchPyTorchModelIO(PatchBaseModelIO):
    __main_task = None
    __patched = None
    __patched_lightning = None

    @staticmethod
    def update_current_task(task, **_):
        PatchPyTorchModelIO.__main_task = task
        PatchPyTorchModelIO._patch_model_io()
        PatchPyTorchModelIO._patch_lightning_io()
        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
        PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)

    @staticmethod
    def _patch_model_io():
        if PatchPyTorchModelIO.__patched:
            return

        if 'torch' not in sys.modules:
            return

        PatchPyTorchModelIO.__patched = True

        # noinspection PyBroadException
        try:
            import torch  # noqa
            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)

            # no need to worry about recursive calls, _patched_call takes care of that
            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_save'):
                torch.serialization._save = _patched_call(
                    torch.serialization._save, PatchPyTorchModelIO._save)  # noqa
            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_load'):
                torch.serialization._load = _patched_call(
                    torch.serialization._load, PatchPyTorchModelIO._load)  # noqa
            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_save'):
                torch.serialization._legacy_save = _patched_call(
                    torch.serialization._legacy_save, PatchPyTorchModelIO._save)  # noqa
            if hasattr(torch, 'serialization') and hasattr(torch.serialization, '_legacy_load'):
                torch.serialization._legacy_load = _patched_call(
                    torch.serialization._legacy_load, PatchPyTorchModelIO._load)  # noqa
        except ImportError:
            pass
        except Exception:
            pass  # print('Failed patching pytorch')

    @staticmethod
    def _patch_lightning_io():
        if PatchPyTorchModelIO.__patched_lightning:
            return

        if 'pytorch_lightning' not in sys.modules:
            return

        PatchPyTorchModelIO.__patched_lightning = True

        # noinspection PyBroadException
        try:
            import pytorch_lightning  # noqa

            pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
                pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save)  # noqa

            pytorch_lightning.trainer.Trainer.restore = _patched_call(
                pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj)  # noqa
        except ImportError:
            pass
        except Exception:
            pass

        # noinspection PyBroadException
        try:
            import pytorch_lightning  # noqa

            # noinspection PyUnresolvedReferences
            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
                _patched_call(
                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
                    PatchPyTorchModelIO._save)  # noqa

            # noinspection PyUnresolvedReferences
            pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
                _patched_call(
                    pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
                    PatchPyTorchModelIO._load_from_obj)  # noqa
        except ImportError:
            pass
        except Exception:
            pass

    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
        # if there is no main task or this is a nested call
        if not PatchPyTorchModelIO.__main_task:
            return ret

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                # noinspection PyBroadException
                try:
                    f.flush()
                except Exception:
                    pass

                if not isinstance(f.name, six.string_types):
                    # Probably a BufferedRandom object that has no meaningful name (still no harm flushing)
                    return ret

                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem if filename is not None else None
        except Exception:
            model_name = None

        WeightsFileHandler.create_output_model(
            obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)

        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchPyTorchModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass

        return model

    @staticmethod
    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
        if not PatchPyTorchModelIO.__main_task:
            return original_fn(obj, f, *args, **kwargs)

        # noinspection PyBroadException
        try:
            if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
            model = original_fn(obj, filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(obj, f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass

        return model