import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain
from typing import *

import numpy as np

from .batch_agg import *
from .callbacks import *
from .data import DataStream
from .events import Event, EventHost
from .mlstorage import ExperimentDoc
from .stage import Stage, StageType
from .typing_ import *
from .utils import to_number_or_numpy, get_array_shape, ALL, NOT_SET, DocInherit

__all__ = [
    'TrainLoop', 'ValidationLoop', 'TestLoop', 'PredictLoop',
]


class _BaseLoopEventCallback(Callback):

    # a sufficiently small priority, such that it should run before almost
    # all callbacks
    priority = -999999

    loop: 'BaseLoop'
    stage: Stage

    def __init__(self,
                 loop: 'BaseLoop',
                 stage: Stage):
        self.loop = loop
        self.stage = stage

    def on_stage_begin(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_begin.fire()

    def on_stage_end(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_end.fire()

    def on_batch_begin(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_batch_begin.fire()

    def on_batch_end(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_batch_end.fire()


class _TrainLoopEventCallback(_BaseLoopEventCallback):

    loop: 'TrainLoop'

    def on_epoch_begin(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_epoch_begin.fire()

    def on_epoch_end(self, data: CallbackData):
        if data.stage == self.stage:
            self.loop.on_epoch_end.fire()


class BaseLoop(metaclass=DocInherit):

    _LoopEventCallbackClass = _BaseLoopEventCallback
    """Callback class that fires the loop events."""

    RUN_BATCHES_DEFAULT_METRICS = ALL
    """Default value for the `metrics` arg in :meth:`run_batches()`."""

    RUN_BATCHES_DEFAULT_OUTPUTS = ()
    """Default value for the `outputs` arg in :meth:`run_batches()`."""

    _callbacks: CallbackList

    logger: LoggerCallback
    """
    The last :class:`LoggerCallback` given to this loop.

    If no :class:`LoggerCallback` is given, then this is the
    auto-created logger callback instance.
    """

    parent: Optional['BaseLoop']
    """The parent loop."""

    _child_stack: List['BaseLoop']
    """The stack of open child loops."""

    _stage: Stage  # the active stage
    _remote_doc: Optional[ExperimentDoc]
    _batch_metrics: Dict[str, Any]
    _epoch_metrics: Dict[str, Any]
    _stage_metrics: Dict[str, Any]

    events: EventHost
    """The event host of this loop object."""

    on_begin: Event
    """Event triggered when loop begins, with signature ``() -> None``."""

    on_end: Event
    """Event triggered when loop ends, with signature ``() -> None``."""

    on_batch_begin: Event
    """Event triggered when batch begins, with signature ``() -> None``."""

    on_batch_end: Event
    """Event triggered when batch ends, with signature ``() -> None``."""

    def __init__(self,
                 stage: Stage,
                 remote_doc: Optional[ExperimentDoc] = NOT_SET,
                 callbacks: Sequence[Callback] = (),
                 parent: Optional['BaseLoop'] = None):
        """
        Construct a new :class:`BaseLoop`.

        Args:
            stage: The stage object.  Note that the `callbacks` list of
                `stage` will be altered.  Do not share stages between
                different loops.
            remote_doc: The experiment document object.
            callbacks: The callbacks.
            parent: The parent loop.
        """
        # construct the default remote doc object, if it is `NOT_SET`
        if remote_doc is NOT_SET:
            remote_doc = ExperimentDoc.default_doc()

        # merge `callbacks` with `stage.callbacks`, sort them into proper
        # order, and add default logger callback if not given.
        callbacks = CallbackList(list(chain(stage.callbacks, callbacks)))
        logger = None
        for cb in reversed(callbacks):
            if isinstance(cb, LoggerCallback):
                logger = cb
                break
        if logger is None:
            logger = LoggerCallback(remote_doc=remote_doc)
            callbacks.add(logger)

        self._callbacks = callbacks
        self.logger = logger

        # also replace the stage callbacks with (a clone of) the merged list
        stage.callbacks = self._callbacks.clone()
        stage.add_callback(self._LoopEventCallbackClass(self, stage))

        self.parent = parent
        self._child_stack = []

        self._stage = stage
        self._remote_doc = remote_doc
        self._batch_metrics = {}
        self._epoch_metrics = {}
        self._stage_metrics = {}

        # bind the events of this object
        self.events = EventHost()
        self.on_begin = self.events['on_begin']
        self.on_end = self.events['on_end']
        self.on_batch_begin = self.events['on_batch_begin']
        self.on_batch_end = self.events['on_batch_end']

    @property
    def batch(self) -> int:
        return self._stage.batch.index

    @property
    def max_batch(self) -> Optional[int]:
        return self._stage.batch.total

    def add_callback(self, callback: Callback):
        """Add a callback to this loop."""
        self._callbacks.add(callback)
        self._stage.add_callback(callback)

    def remove_callback(self, callback: Callback):
        """Remove a callback from this loop."""
        self._callbacks.remove(callback)
        self._stage.remove_callback(callback)

    def add_metrics(self,
                    metrics_: Optional[Dict[str, Any]] = None,
                    add_to_child_: bool = True,
                    **kwargs: Any) -> None:
        """
        Add metrics to the loop.

        Args:
            metrics_, \\**kwargs: The metrics to be collected.
                The names of the metrics will be ensured to have proper
                prefix, according to the loop type.
                See :meth:`mltk.StageType.add_metric_prefix` for more details.
            add_to_child_: If :obj:`True`, will add the metrics to the nearest
                child loop instead of adding to this loop, if any child loop
                context is currently open.
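
        Example:
            A minimal sketch; ``loop``, ``stream`` and ``train_step`` are
            illustrative names for an open loop, a mini-batch source, and
            a user-defined batch function::

                for batch, [x, y] in loop.iter_batches(stream):
                    loss = train_step(x, y)
                    # the metric name receives the stage-specific prefix
                    loop.add_metrics(loss=loss)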
        """
        if self._child_stack and add_to_child_:
            self._child_stack[-1].add_metrics(metrics_, **kwargs)
        else:
            def collect(target: Dict[str, Any]):
                if metrics_:
                    for key, val in metrics_.items():
                        key = stage_type.add_metric_prefix(key)
                        target[key] = to_number_or_numpy(val)
                if kwargs:
                    for key, val in kwargs.items():
                        key = stage_type.add_metric_prefix(key)
                        target[key] = to_number_or_numpy(val)

            stage_type = self._stage.type
            if self._stage.batch.is_active:
                collect(self._batch_metrics)
            elif self._stage.epoch is not None and self._stage.epoch.is_active:
                collect(self._epoch_metrics)
            else:
                collect(self._stage_metrics)
                self._stage.push_metrics(self._stage_metrics)

    @contextmanager
    def timeit(self, metric_name: str):
        """
        Open a context, measure the elapsed time between entering and
        exiting the context, and add the time metric to this loop.

        Args:
            metric_name: The name of the time metric.
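
        Example:
            A minimal sketch; ``load_next_batch()`` is an illustrative
            function, and the metric name must end with "_time" or
            "_timer"::

                with loop.timeit('data_time'):
                    batch_data = load_next_batch()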
        """
        suffix = metric_name.rsplit('_', 1)[-1]
        if suffix not in ('time', 'timer'):
            raise ValueError(f'The metric name for a timer should end with '
                             f'suffix "_time" or "_timer": got metric name '
                             f'{metric_name!r}')
        start_time = time.time()
        try:
            yield
        finally:
            self.add_metrics({metric_name: time.time() - start_time})

    def _iter_batches(self,
                      data_generator: Optional[
                          Iterable[ArrayTupleOrList]] = None,
                      limit: Optional[int] = None,
                      count: Optional[int] = None,
                      ) -> BatchGenerator:
        # inspect the data generator to complete the total number of batches,
        # if neither `limit` nor `count` is specified
        if data_generator is not None and count is None and limit is None:
            g_info = inspect_data_generator(data_generator)
            if g_info.batch_count is not None and \
                    self._stage.batch.total is None:
                self._stage.batch.total = (self._stage.batch.index +
                                           g_info.batch_count)

        # get the upper limit of `batch.index`
        if limit is not None:
            batch_limit = limit
        elif count is not None:
            # `batch.index` points to the previously completed batch, so
            # running `count` more batches stops at `index + count`.
            batch_limit = self._stage.batch.index + count
        else:
            batch_limit = self._stage.batch.total

        if self._stage.batch.total is not None:
            batch_limit = min(self._stage.batch.total, batch_limit)

        # convert `data_generator` into iterator
        close_data_iterator = False
        if data_generator is not None:
            data_iterator = iter(data_generator)
            if isinstance(data_generator, DataStream):
                # we've just obtained a temporary iterator from the DataStream,
                # thus it's our responsibility to close it.
                close_data_iterator = True
        else:
            data_iterator = None

        # now run the loop
        try:
            if data_iterator is not None:
                while not self._stage.termination_requested and \
                        (batch_limit is None or
                         self._stage.batch.index < batch_limit):
                    try:
                        batch_data = next(data_iterator)
                    except StopIteration:
                        break

                    # check batch data and inspect batch size
                    if not isinstance(batch_data, (tuple, list)) or \
                            not batch_data:
                        raise ValueError(
                            f'`data_generator` did not yield a non-empty tuple '
                            f'or list of arrays: got {batch_data!r}'
                        )
                    batch_size = len(batch_data[0])

                    # now run the batch
                    self._batch_metrics.clear()
                    self._stage.enter_batch(batch_size=batch_size)
                    try:
                        yield self._stage.batch.index, batch_data
                    finally:
                        self._stage.exit_batch(self._batch_metrics)

            else:
                while self._stage.batch.index < batch_limit:
                    self._batch_metrics.clear()
                    self._stage.enter_batch()
                    try:
                        yield self._stage.batch.index
                    finally:
                        self._stage.exit_batch(self._batch_metrics)
        finally:
            if close_data_iterator:
                data_iterator.close()

    def iter_batches(self,
                     data_generator: Optional[
                         Iterable[ArrayTupleOrList]] = None,
                     limit: Optional[int] = None,
                     count: Optional[int] = None,
                     ) -> BatchGenerator:
        """
        Iterate through the batches.

        Args:
            data_generator: Mini-batch data generator, yielding tuples of arrays.
            limit: The maximum batch index to reach, i.e., ``index <= limit``
                is a loop constraint on the batch counter.
            count: The maximum number of batches to run.

        Yields:
            (int, Tuple[np.ndarray, ...]): The batch index and mini-batch
                arrays, if `data_generator` is specified.
            int: The batch index, if `data_generator` is not specified.
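
        Example:
            A minimal sketch, assuming ``stream`` is an illustrative data
            generator yielding ``(x, y)`` array pairs::

                for batch, (x, y) in loop.iter_batches(stream):
                    ...  # process one mini-batch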
        """
        # check the context
        if not self._stage.is_active:
            raise RuntimeError('The loop context must be entered before '
                               'calling `iter_batches()`.')
        if self._stage.batch.is_active:
            raise RuntimeError('`iter_batches()` cannot be called when a '
                               'batch is currently running.')

        # check the arguments
        if count is not None and limit is not None:
            raise ValueError('`count` and `limit` cannot be both specified.')

        # we do not allow an infinite loop
        if data_generator is None and count is None and limit is None and \
                self._stage.batch.total is None:
            raise ValueError(
                'Any one of `data_generator`, `limit` or `count` is required '
                'to be specified when `max_batch` is not configured for '
                'the loop.')

        return self._iter_batches(
            data_generator=data_generator,
            limit=limit,
            count=count,
        )

    def _complete_metrics_and_outputs_arg(self, metrics, outputs):
        if metrics is NOT_SET and outputs == ALL:
            metrics = ()
        elif outputs is NOT_SET and metrics == ALL:
            outputs = ()
        else:
            if metrics is NOT_SET:
                metrics = self.RUN_BATCHES_DEFAULT_METRICS
            if outputs is NOT_SET:
                outputs = self.RUN_BATCHES_DEFAULT_OUTPUTS
        return metrics, outputs

    def run_batches(self,
                    fn: Callable[..., Optional[Dict[str, Any]]],
                    data_generator: Iterable[ArrayTupleOrList],
                    limit: Optional[int] = None,
                    count: Optional[int] = None,
                    metrics: Union[Sequence[str], type(ALL)] = NOT_SET,
                    outputs: Union[Sequence[str], type(ALL)] = NOT_SET,
                    aggregators: Optional[Mapping[str, BatchAggregator]] = None,
                    excludes: Sequence[str] = ()
                    ) -> Optional[Dict[str, Any]]:
        """
        Run batches with the specified batch function `fn`.

        Args:
            fn: The batch function to execute at each batch.
                The signature of `fn` should be ``(*arrays) -> None`` or
                ``(*arrays) -> Dict[str, Any]``, which consumes the batch
                arrays produced by `data_generator`, and (maybe) returns the
                batch metrics and outputs.
            data_generator: Mini-batch data generator, yielding tuples of arrays.
            limit: The maximum batch index to reach, i.e., ``index <= limit``
                is a loop constraint on the batch counter.
            count: The maximum number of batches to run.
            metrics: Names of metrics produced by `fn`.  These metrics will
                be aggregated by ``BatchAggregator('AVERAGE', axis=None)``,
                reported by ``self.logger``, and returned by this method.
                Defaults to ``self.RUN_BATCHES_DEFAULT_METRICS``.
            outputs: Names of outputs produced by `fn`.  These outputs will
                be aggregated by ``BatchAggregator('CONCAT', axis=0)``,
                and returned by this method.
                Defaults to ``self.RUN_BATCHES_DEFAULT_OUTPUTS``.
            aggregators: Dict from name to custom batch aggregators.
            excludes: Names of items produced by `fn` to exclude.  An excluded
                name will not be collected by any :class:`BatchAggregator`.

        Returns:
            The aggregated metrics and outputs.
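
        Example:
            A minimal sketch; ``train_step`` is an illustrative batch
            function returning ``{'loss': ...}``::

                results = loop.run_batches(
                    train_step, stream, metrics=['loss'])
                # results['loss'] is the loss averaged over all batches
                # (weighted by batch size)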
        """
        metrics, outputs = \
            self._complete_metrics_and_outputs_arg(metrics, outputs)

        # the BatchAggregatorDict
        agg_dict = BatchAggregatorDict.new(
            metrics=metrics,
            outputs=outputs,
            aggregators=aggregators,
            excludes=excludes,
        )

        # now run the batches
        g = self.iter_batches(data_generator, limit=limit, count=count)
        try:
            for batch, batch_data in g:
                batch_size = get_array_shape(batch_data[0])[0]
                fn_out = fn(*batch_data)
                if fn_out is not None:
                    if not isinstance(fn_out, dict):
                        raise TypeError(f'The output of `fn` is expected to be '
                                        f'a dict, but got {fn_out!r}')

                    fn_out = {
                        k: to_number_or_numpy(v)
                        for k, v in fn_out.items()
                    }
                    # the scalar metrics collected from this batch; use a
                    # dedicated dict to avoid shadowing the `metrics` argument
                    batch_scalar_metrics = {}
                    for key, val in fn_out.items():
                        agg = agg_dict.get(key)
                        if agg is not None:
                            # scalar metrics are weighted by the batch size
                            size = batch_size if np.shape(val) == () else 1.
                            agg.add(val, weight=size)
                            # For metrics collected by
                            # ``BatchAggregator('AVERAGE', None)``, we also add
                            # them to the batch metrics.
                            if agg.mode == BatchAggregationMode.AVERAGE and \
                                    agg.axis is None:
                                batch_scalar_metrics[key] = np.mean(val)
                    self.add_metrics(batch_scalar_metrics)
        finally:
            g.close()

        # return the aggregated results
        if len(agg_dict) > 0:
            return {k: v.get() for k, v in agg_dict.items()}

    def __enter__(self):
        if self._stage.is_active:
            raise RuntimeError(f'{self.__class__.__qualname__} is not '
                               f're-entrant.')
        self._stage_metrics.clear()
        self._stage.enter()
        if self.parent is not None:
            self.parent._child_stack.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.parent is not None:
            self.parent._child_stack.pop()
        self._stage.exit(self._stage_metrics)


class AfterEveryFewCyclesCallback(object):

    loop: 'TrainLoop'
    fn: Callable[[], None]
    on_error: bool

    def __init__(self, fn: Callable[[], None], loop: 'TrainLoop', on_error: bool):
        self.fn = fn
        self.loop = loop
        self.on_error = on_error

    def _call(self):
        raise NotImplementedError()

    def __call__(self):
        if self.on_error or sys.exc_info()[0] is None:
            return self._call()


class AfterEveryFewEpochsCallback(AfterEveryFewCyclesCallback):

    epochs: int

    def __init__(self, fn: Callable[[], None], loop: 'TrainLoop', epochs: int,
                 on_error: bool):
        if epochs <= 0 or abs(epochs - int(epochs)) > 1e-6:
            raise ValueError(f'`epochs` must be a positive integer: got {epochs}')

        super().__init__(fn, loop, on_error)
        self.epochs = int(epochs)

    def _call(self):
        if self.loop.epoch % self.epochs == 0:
            return self.fn()


class AfterEveryFewBatchesCallback(AfterEveryFewCyclesCallback):

    batches: int

    def __init__(self, fn: Callable[[], None], loop: 'TrainLoop', batches: int,
                 on_error: bool):
        if batches <= 0 or abs(batches - int(batches)) > 1e-6:
            raise ValueError(f'`batches` must be a positive integer: got {batches}')

        super().__init__(fn, loop, on_error)
        self.batches = int(batches)

    def _call(self):
        if self.loop.batch % self.batches == 0:
            return self.fn()


class TrainLoop(BaseLoop):

    _LoopEventCallbackClass = _TrainLoopEventCallback

    only_batch: bool
    """Whether or not this train loop only runs batches, without epochs?"""

    on_epoch_begin: Event
    """Event triggered when epoch begins, with signature ``() -> None``."""

    on_epoch_end: Event
    """Event triggered when epoch ends, with signature ``() -> None``."""

    def __init__(self,
                 max_epoch: Optional[int] = None,
                 max_batch: Optional[int] = None,
                 only_batch: bool = False,
                 remote_doc: Optional[ExperimentDoc] = NOT_SET,
                 callbacks: Sequence[Callback] = ()):
        """
        Construct a new :class:`TrainLoop`.

        Args:
            max_epoch: The maximum index for the epoch counter to reach.
            max_batch: The maximum index for the batch counter to reach.
            only_batch: Whether to iterate only through batches, without
                explicitly iterating through epochs.
                If :obj:`True`, an epoch will be opened automatically when
                entering the loop, and closed when exiting the loop.
            remote_doc: The experiment document object.
            callbacks: The callbacks.
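
        Example:
            A minimal sketch; ``stream`` and ``train_step`` are
            illustrative::

                with TrainLoop(max_epoch=10) as loop:
                    for epoch in loop.iter_epochs():
                        for batch, [x, y] in loop.iter_batches(stream):
                            loop.add_metrics(loss=train_step(x, y))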
        """
        only_batch = bool(only_batch)
        if only_batch and max_epoch is not None:
            raise ValueError('`max_epoch` must not be specified when '
                             '`only_batch` is set to True.')

        super().__init__(
            stage=Stage(
                type=StageType.TRAIN,
                max_epoch=max_epoch,
                max_batch=max_batch,
            ),
            remote_doc=remote_doc,
            callbacks=callbacks,
        )
        self.only_batch = only_batch
        self.on_epoch_begin = self.events['on_epoch_begin']
        self.on_epoch_end = self.events['on_epoch_end']

    @property
    def epoch(self):
        return self._stage.epoch.index

    @property
    def max_epoch(self):
        return self._stage.epoch.total

    def run_after_every(self,
                        fn: Callable[[], None],
                        *,
                        epochs: Optional[int] = None,
                        batches: Optional[int] = None,
                        on_error: bool = False,
                        ) -> Optional[AfterEveryFewCyclesCallback]:
        """
        Register a callback that runs after every few epochs or batches.

        Args:
            fn: The callback to run.
            epochs: The number of epochs.
            batches: The number of batches.
            on_error: Whether `fn` should still run when an error is being
                propagated.  If :obj:`False`, `fn` will be skipped on error.

        Returns:
            A callback object, which can be un-registered via
            :meth:`remove_after_every`, if either `epochs` or `batches`
            is specified; otherwise :obj:`None`.
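
        Example:
            A minimal sketch that runs an illustrative ``save_checkpoint``
            function after every 5 epochs::

                cb = loop.run_after_every(save_checkpoint, epochs=5)
                ...  # run the training
                loop.remove_after_every(cb)  # un-register the callback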
        """
        if epochs is not None and batches is not None:
            raise ValueError('`epochs` and `batches` cannot be both specified.')

        if epochs is not None:
            cb = AfterEveryFewEpochsCallback(fn, self, epochs, on_error)
            self.on_epoch_end.do(cb)
        elif batches is not None:
            cb = AfterEveryFewBatchesCallback(fn, self, batches, on_error)
            self.on_batch_end.do(cb)
        else:
            cb = None

        return cb

    def remove_after_every(self, cb: Optional[AfterEveryFewCyclesCallback]):
        """
        Remove a callback registered by :meth:`run_after_every()`.

        Args:
            cb: The callback object.
        """
        if cb is not None:
            if isinstance(cb, AfterEveryFewEpochsCallback):
                self.on_epoch_end.cancel_do(cb)
            elif isinstance(cb, AfterEveryFewBatchesCallback):
                self.on_batch_end.cancel_do(cb)
            else:  # pragma: no cover
                raise TypeError(f'Unsupported callback: {cb!r}')

    def iter_batches(self,
                     data_generator: Optional[
                         Iterable[ArrayTupleOrList]] = None,
                     limit: Optional[int] = None,
                     count: Optional[int] = None
                     ) -> BatchGenerator:
        if not self._stage.epoch.is_active:
            raise RuntimeError(
                'The batch loop can only be opened inside an epoch loop.  '
                'Did you forget to call `iter_epochs()`?'
            )
        return super().iter_batches(
            data_generator=data_generator,
            limit=limit,
            count=count,
        )

    def _iter_epochs(self,
                     limit: Optional[int] = None,
                     count: Optional[int] = None
                     ) -> Generator[int, None, None]:
        # get the upper limit of `epoch.index`
        if limit is not None:
            epoch_limit = limit
        elif count is not None:
            # `epoch.index` points to the previously completed epoch;
            # see the similar note in `_iter_batches()`
            epoch_limit = self._stage.epoch.index + count
        else:
            epoch_limit = self._stage.epoch.total

        if self._stage.epoch.total is not None:
            epoch_limit = min(self._stage.epoch.total, epoch_limit)

        # now run the loop
        while not self._stage.termination_requested and \
                self._stage.epoch.index < epoch_limit:
            self._epoch_metrics.clear()
            self._stage.enter_epoch()
            try:
                yield self._stage.epoch.index
            finally:
                self._stage.exit_epoch(self._epoch_metrics)

    def iter_epochs(self,
                    limit: Optional[int] = None,
                    count: Optional[int] = None
                    ) -> Generator[int, None, None]:
        """
        Iterate through the epochs.

        Args:
            limit: The maximum epoch index to reach, i.e., ``index <= limit``
                is a loop constraint on the epoch counter.
            count: The maximum number of epochs to run.

        Yields:
            int: The epoch index.
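
        Example:
            A minimal sketch, running at most 10 more epochs (``stream``
            is an illustrative data generator)::

                for epoch in loop.iter_epochs(count=10):
                    for batch, [x, y] in loop.iter_batches(stream):
                        ...  # train on one mini-batch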
        """
        # check the context
        if self.only_batch:
            raise RuntimeError('The loop is configured with `only_batch = True`'
                               ', thus `iter_epochs()` is prohibited.')
        if not self._stage.is_active:
            raise RuntimeError('The loop context must be entered before '
                               'calling `iter_epochs()`.')
        if self._stage.epoch.is_active:
            raise RuntimeError('`iter_epochs()` is not re-entrant.')

        # check the arguments
        if count is not None and limit is not None:
            raise ValueError('`count` and `limit` cannot be both specified.')

        # we do not allow an infinite loop
        if limit is None and count is None and self._stage.epoch.total is None:
            raise ValueError(
                'Either `limit` or `count` is required to be specified when '
                '`max_epoch` is not configured for the loop.')

        return self._iter_epochs(limit=limit, count=count)

    def run_epochs(self,
                   fn: Callable[..., Optional[Dict[str, Any]]],
                   data_generator: Iterable[ArrayTupleOrList],
                   limit: Optional[int] = None,
                   count: Optional[int] = None,
                   metrics: Union[Sequence[str], type(ALL)] = NOT_SET,
                   excludes: Sequence[str] = ()
                   ) -> None:
        """
        Run epochs and the batches in each epoch with the specified batch
        function `fn`.

        Args:
            fn: The batch function to execute at each batch.
                The signature of `fn` should be ``(*arrays) -> None`` or
                ``(*arrays) -> Dict[str, Any]``, which consumes the batch
                arrays produced by `data_generator`, and (maybe) returns the
                batch metrics.
            data_generator: Mini-batch data generator, yielding tuples of arrays.
            limit: The maximum epoch index to reach, i.e., ``index <= limit``
                is a loop constraint on the epoch counter.
            count: The maximum number of epochs to run.
            metrics: Names of metrics produced by `fn`.  These metrics will
                be aggregated by ``BatchAggregator('AVERAGE', axis=None)``,
                and reported by ``self.logger``.
            excludes: Names of items produced by `fn` to exclude.  An excluded
                name will not be collected by any :class:`BatchAggregator`.

        Notes:
            Unlike :meth:`run_batches()`, this method will not return the
            collected metrics.  Consider using :meth:`run_batches()` with an
            explicit epoch loop if you need to obtain the metrics.
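
        Example:
            A minimal sketch, equivalent to calling :meth:`run_batches()`
            inside an epoch loop (``train_step`` and ``stream`` are
            illustrative)::

                loop.run_epochs(train_step, stream, count=10,
                                metrics=['loss'])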
        """
        g = self.iter_epochs(limit=limit, count=count)
        try:
            for _ in g:
                self.run_batches(
                    fn, data_generator, metrics=metrics, excludes=excludes)
        finally:
            g.close()

    def run(self,
            fn: Callable[..., Optional[Dict[str, Any]]],
            data_generator: Iterable[ArrayTupleOrList],
            metrics: Union[Sequence[str], type(ALL)] = NOT_SET,
            excludes: Sequence[str] = (),
            **kwargs
            ) -> Optional[Dict[str, Any]]:
        """
        Run the train loop.

        Args:
            fn: The batch function to execute at each batch.
                The signature of `fn` should be ``(*arrays) -> None`` or
                ``(*arrays) -> Dict[str, Any]``, which consumes the batch
                arrays produced by `data_generator`, and (maybe) returns the
                batch metrics.
            data_generator: Mini-batch data generator, yielding tuples of arrays.
            metrics: Names of metrics produced by `fn`.  These metrics will
                be aggregated by ``BatchAggregator('AVERAGE', axis=None)``,
                and reported by ``self.logger``.
            excludes: Names of items produced by `fn` to exclude.  An excluded
                name will not be collected by any :class:`BatchAggregator`.
            \\**kwargs: Named parameters passed to `run_batches(...)` or
                `run_epochs(...)`.

        Returns:
            If ``self.only_batch == True``, then the collected metrics will
            be returned.  Otherwise the return value will always be :obj:`None`.
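
        Example:
            A minimal end-to-end sketch; the loop context is entered
            automatically if it is not already active (``train_step`` and
            ``stream`` are illustrative)::

                TrainLoop(max_epoch=10).run(train_step, stream)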
        """
        run_fn = self.run_batches if self.only_batch else self.run_epochs
        def run_fn_with_args():
            return run_fn(
                fn,
                data_generator,
                metrics=metrics,
                excludes=excludes,
                **kwargs
            )

        if not self._stage.is_active:
            with self:
                return run_fn_with_args()
        else:
            return run_fn_with_args()

    def validation(self) -> 'ValidationLoop':
        """
        Construct a new :class:`ValidationLoop` that inherits callbacks and
        other states from this train loop.

        This is the recommended way to obtain a validation loop inside a train loop.
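
        Example:
            A minimal sketch; ``eval_step`` and ``val_stream`` are
            illustrative::

                with train_loop.validation() as val_loop:
                    val_loop.run(eval_step, val_stream)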
        """
        return ValidationLoop(
            remote_doc=self._remote_doc,
            callbacks=self._callbacks,
            parent=self,
        )

    def test(self) -> 'TestLoop':
        """
        Construct a new :class:`TestLoop` that inherits callbacks and
        other states from this train loop.

        This is the recommended way to obtain a test loop inside a train loop.
        """
        return TestLoop(
            remote_doc=self._remote_doc,
            callbacks=self._callbacks,
            parent=self,
        )

    def predict(self) -> 'PredictLoop':
        """
        Construct a new :class:`PredictLoop` that inherits callbacks and
        other states from this train loop.

        This is the recommended way to obtain a predict loop inside a train loop.
        """
        return PredictLoop(
            remote_doc=self._remote_doc,
            callbacks=self._callbacks,
            parent=self,
        )

    def __enter__(self):
        super().__enter__()
        # open the first epoch if `only_batch` is True
        if self.only_batch:
            self._stage.enter_epoch(1)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.only_batch:
            self._stage.exit_epoch(self._epoch_metrics)
        return super().__exit__(exc_type, exc_val, exc_tb)


class _BatchOnlyLoop(BaseLoop):

    def run(self,
            fn: Callable[..., Optional[Dict[str, Any]]],
            data_generator: Iterable[ArrayTupleOrList],
            metrics: Union[Sequence[str], type(ALL)] = NOT_SET,
            outputs: Union[Sequence[str], type(ALL)] = NOT_SET,
            aggregators: Optional[Mapping[str, BatchAggregator]] = None,
            excludes: Sequence[str] = ()
            ) -> Optional[Dict[str, Any]]:
        """
        Run the loop.

        Args:
            fn: The batch function to execute at each batch.
                The signature of `fn` should be ``(*arrays) -> None`` or
                ``(*arrays) -> Dict[str, Any]``, which consumes the batch
                arrays produced by `data_generator`, and (maybe) returns the
                batch metrics or outputs.
            data_generator: Mini-batch data generator, yielding tuples of arrays.
            metrics: Names of metrics produced by `fn`.  These metrics will
                be aggregated by ``BatchAggregator('AVERAGE', axis=None)``,
                reported by ``self.logger``, and returned by this method.
                Defaults to ``self.RUN_BATCHES_DEFAULT_METRICS``.
            outputs: Names of outputs produced by `fn`.  These outputs will
                be aggregated by ``BatchAggregator('CONCAT', axis=0)``,
                and returned by this method.
                Defaults to ``self.RUN_BATCHES_DEFAULT_OUTPUTS``.
            aggregators: Dict from name to custom batch aggregators.
            excludes: Names of items produced by `fn` to exclude.  An excluded
                name will not be collected by any :class:`BatchAggregator`.

        Returns:
            The aggregated metrics and outputs.
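
        Example:
            A minimal sketch collecting per-sample outputs from an
            illustrative ``predict_step`` function::

                results = loop.run(predict_step, stream,
                                   outputs=['y_prob'])
                # results['y_prob'] concatenates batch outputs along axis 0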
        """
        if not self._stage.is_active:
            with self:
                return self.run_batches(
                    fn, data_generator, metrics=metrics, outputs=outputs,
                    aggregators=aggregators, excludes=excludes
                )
        else:
            return self.run_batches(
                fn, data_generator, metrics=metrics, outputs=outputs,
                aggregators=aggregators, excludes=excludes
            )


class ValidationLoop(_BatchOnlyLoop):

    def __init__(self,
                 remote_doc: Optional[ExperimentDoc] = NOT_SET,
                 callbacks: Sequence[Callback] = (),
                 parent: Optional[BaseLoop] = None):
        super().__init__(
            stage=Stage(type=StageType.VALIDATION),
            remote_doc=remote_doc,
            callbacks=callbacks,
            parent=parent,
        )


class TestLoop(_BatchOnlyLoop):

    def __init__(self,
                 remote_doc: Optional[ExperimentDoc] = NOT_SET,
                 callbacks: Sequence[Callback] = (),
                 parent: Optional[BaseLoop] = None):
        super().__init__(
            stage=Stage(type=StageType.TEST),
            remote_doc=remote_doc,
            callbacks=callbacks,
            parent=parent,
        )


class PredictLoop(_BatchOnlyLoop):

    RUN_BATCHES_DEFAULT_METRICS = ()
    RUN_BATCHES_DEFAULT_OUTPUTS = ALL

    def __init__(self,
                 remote_doc: Optional[ExperimentDoc] = NOT_SET,
                 callbacks: Sequence[Callback] = (),
                 parent: Optional[BaseLoop] = None):
        super().__init__(
            stage=Stage(type=StageType.PREDICT),
            remote_doc=remote_doc,
            callbacks=callbacks,
            parent=parent,
        )


@dataclass
class DataGeneratorInfo(object):

    __slots__ = ('data_length', 'batch_size', 'batch_count')

    data_length: Optional[int]
    batch_size: Optional[int]
    batch_count: Optional[int]


def inspect_data_generator(g) -> Union[DataGeneratorInfo, DataStream]:
    if isinstance(g, DataStream):
        # since `DataStream` has all the interface of `DataGeneratorInfo`,
        # we just return it without constructing a new object
        return g
    return DataGeneratorInfo(data_length=None, batch_size=None,
                             batch_count=None)