import operator
from enum import Enum
from functools import reduce
from typing import *

import numpy as np

from .stage import StageType
from .utils import ALL, NOT_SET

__all__ = [
    'BatchAggregationMode',
    'BatchAggregator', 'BatchAggregatorDict',
]


class BatchAggregationMode(str, Enum):

    CONCAT = 'CONCAT'
    """Concatenate the batch arrays along the specified axis."""

    SUM = 'SUM'
    """Sum the batch arrays along the specified axis."""

    AVERAGE = 'AVERAGE'
    """Average the batch arrays along the specified axis."""


class BatchAggregator(object):
    """
    Class to aggregate batch arrays.

    >>> agg = BatchAggregator(BatchAggregationMode.CONCAT)
    >>> agg
    BatchAggregator(mode=CONCAT, axis=0)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    array([1, 2, 3, 4, 5, 6])

    >>> agg = BatchAggregator(BatchAggregationMode.AVERAGE)
    >>> agg
    BatchAggregator(mode=AVERAGE, axis=None)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    3.5

    >>> agg = BatchAggregator(BatchAggregationMode.SUM)
    >>> agg
    BatchAggregator(mode=SUM, axis=None)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    21
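
    In `AVERAGE` mode each batch may carry its own `weight`; here the result
    is the weighted mean ``(1 * 3 + 2 * 4) / (1 * 2 + 2 * 1) = 2.75``:

    >>> agg = BatchAggregator(BatchAggregationMode.AVERAGE)
    >>> agg.add(np.array([1., 2.]), weight=1.)
    >>> agg.add(np.array([4.]), weight=2.)
    >>> agg.get()
    2.75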
    """

    mode: BatchAggregationMode
    axis: Optional[Union[int, Tuple[int, ...]]]

    def __init__(self,
                 mode: Union[str, BatchAggregationMode],
                 axis: Optional[Union[int, Tuple[int, ...], List[int]]] = NOT_SET):
        """
        Construct a new :class:`BatchAggregator`.

        Args:
            mode: Aggregation mode.
            axis: The axis to aggregate.  Defaults to `0` for `CONCAT` mode,
                while :obj:`None` for `SUM` and `AVERAGE` mode.
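
        For example, list or tuple axes are normalized to tuples, and a
        single-element list collapses to an int:

        >>> BatchAggregator('CONCAT').axis
        0
        >>> BatchAggregator('SUM').axis is None
        True
        >>> BatchAggregator('SUM', axis=[0, 1]).axis
        (0, 1)
        >>> BatchAggregator('SUM', axis=[1]).axis
        1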
        """
        mode = BatchAggregationMode(mode)

        if axis is NOT_SET:
            axis = 0 if mode == BatchAggregationMode.CONCAT else None
        if mode == BatchAggregationMode.CONCAT:
            if not isinstance(axis, int):
                raise TypeError('`axis` must be an int when `mode` is CONCAT.')
        if axis is not None:
            if hasattr(axis, '__iter__'):
                axis = tuple(int(v) for v in axis)
                if len(axis) == 1:
                    axis = axis[0]
            else:
                axis = int(axis)

        self.mode = mode
        self.axis = axis
        self._buf = None
        self._weight_sum = 0.

    def __repr__(self):
        return f'{self.__class__.__qualname__}' \
               f'(mode={self.mode.value}, axis={self.axis})'

    def get(self) -> Optional[np.ndarray]:
        """
        Get the aggregation result.

        Returns:
            The result, or :obj:`None` if no value has been collected.
        """
        if self._buf is not None:
            if self.mode == BatchAggregationMode.CONCAT:
                return np.concatenate(self._buf, axis=self.axis)
            else:
                return self._buf

    def add(self,
            values: np.ndarray,
            weight: float = 1.):
        """
        Add a batch array to the aggregator.

        Args:
            values: The batch array.
            weight: The batch weight, used only in `AVERAGE` mode.
        """
        # CONCAT: append the values to the buf
        if self.mode == BatchAggregationMode.CONCAT:
            if self._buf is None:
                self._buf = []
            self._buf.append(values)

        # SUM
        elif self.mode == BatchAggregationMode.SUM:
            batch_sum = np.sum(values, axis=self.axis)
            if self._buf is None:
                self._buf = batch_sum
            else:
                self._buf += batch_sum

        # AVERAGE: maintain the `total_weight` state and update the buf
        else:
            # infer the batch size and weight
            batch_shape = np.shape(values)
            if self.axis is None:
                batch_size = float(reduce(operator.mul, batch_shape, 1.))
            elif isinstance(self.axis, tuple):
                batch_size = 1.
                for a in self.axis:
                    batch_size *= batch_shape[a]
            else:
                batch_size = batch_shape[self.axis]
            batch_weight = weight * batch_size

            # incrementally update the weighted mean, so that `_buf` always
            # equals sum_b(weight_b * batch_sum_b) / sum_b(weight_b * batch_size_b)
            self._weight_sum += batch_weight
            r1 = weight / self._weight_sum
            batch_sum = np.sum(values, axis=self.axis)
            if self._buf is None:
                self._buf = r1 * batch_sum
            else:
                r2 = batch_weight / self._weight_sum
                self._buf += r1 * batch_sum - r2 * self._buf


class BatchAggregatorDict(Mapping[str, BatchAggregator]):
    """
    Maintain a dict of :class:`BatchAggregator` instances, possibly with
    a default factory that constructs :class:`BatchAggregator` instances
    for new keys.

    >>> agg_dict = BatchAggregatorDict.new()
    >>> agg_dict['acc'].add(np.array([0.75, 0.875]))
    >>> agg_dict['loss'].add(np.array([0.125, 0.2]))
    >>> len(agg_dict)
    2
    >>> list(agg_dict)
    ['acc', 'loss']
    >>> agg_dict['acc'].get()
    0.8125
    >>> agg_dict['loss'].get()
    0.1625
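
    With ``outputs=ALL``, new keys are aggregated by concatenation instead
    (this assumes :obj:`ALL` is available in the doctest namespace, like
    ``np`` above):

    >>> agg_dict = BatchAggregatorDict.new(metrics=(), outputs=ALL)
    >>> agg_dict['y'].add(np.array([1, 2]))
    >>> agg_dict['y'].add(np.array([3]))
    >>> agg_dict['y'].get()
    array([1, 2, 3])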
    """

    @staticmethod
    def new(metrics: Union[Sequence[str], type(ALL)] = ALL,
            outputs: Union[Sequence[str], type(ALL)] = (),
            aggregators: Optional[Mapping[str, BatchAggregator]] = None,
            excludes: Sequence[str] = (),
            stage_type: Optional[StageType] = None) -> 'BatchAggregatorDict':
        """
        Construct a new :class:`BatchAggregatorDict` according to the field
        settings `metrics`, `outputs` and `aggregators`.

        Args:
            metrics: The names of the batch arrays which should be aggregated
                by ``BatchAggregator('AVERAGE', axis=None)``.  :obj:`ALL`
                indicates that an array is by default a metric if it is
                specified neither in `outputs` nor in `aggregators`.
            outputs: The names of the batch arrays which should be aggregated
                by ``BatchAggregator('CONCAT', axis=0)``.  :obj:`ALL`
                indicates that an array is by default an output if it is
                specified neither in `metrics` nor in `aggregators`.
            aggregators: The dict of names and their corresponding aggregators.
            excludes: The names to exclude.  If a name is excluded, no
                aggregator will be assigned to it, i.e., ``get(name)``
                returns None, and ``__getitem__(name)`` raises
                :class:`KeyError`.
            stage_type: If specified, will add stage metric prefix to the keys
                of `metrics`, `outputs` and `aggregators`.

        Returns:
            The aggregator dict.

        Notes:
            :obj:`ALL` may be specified for at most one of `metrics`
            and `outputs`.  The argument `aggregators` takes priority over
            `outputs`, and `outputs` takes priority over `metrics`.
            That is to say, if a name is specified in both `aggregators`
            and `outputs`, the aggregator from `aggregators` will be chosen;
            likewise, `outputs` wins if a name appears in both `outputs`
            and `metrics`.
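
        For example (``acc``, ``y`` and ``z`` are hypothetical array names):

        >>> agg_dict = BatchAggregatorDict.new(
        ...     metrics=['acc'], outputs=['y'], excludes=['z'])
        >>> agg_dict['acc']
        BatchAggregator(mode=AVERAGE, axis=None)
        >>> agg_dict['y']
        BatchAggregator(mode=CONCAT, axis=0)
        >>> agg_dict.get('z') is None
        True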
        """
        # the aggregator factories
        average_aggregator_factory = lambda: \
            BatchAggregator(mode=BatchAggregationMode.AVERAGE, axis=None)
        concat_aggregator_factory = lambda: \
            BatchAggregator(mode=BatchAggregationMode.CONCAT, axis=0)

        # determine the default factory
        if metrics == ALL and outputs == ALL:
            raise ValueError('Only one of `metrics` and `outputs` can be '
                             '`ALL`.')
        elif metrics == ALL:
            default_factory = average_aggregator_factory
        elif outputs == ALL:
            default_factory = concat_aggregator_factory
        else:
            default_factory = None

        # build the aggregator instances
        agg_dict = {}
        if metrics != ALL and metrics:
            for key in metrics:
                if stage_type is not None:
                    key = stage_type.add_metric_prefix(key)
                agg_dict[key] = average_aggregator_factory()
        if outputs != ALL and outputs:
            for key in outputs:
                if stage_type is not None:
                    key = stage_type.add_metric_prefix(key)
                agg_dict[key] = concat_aggregator_factory()
        if aggregators:
            for key, agg in aggregators.items():
                if stage_type is not None:
                    key = stage_type.add_metric_prefix(key)
                agg_dict[key] = agg

        # build the excludes names
        if excludes and stage_type is not None:
            excludes = [stage_type.add_metric_prefix(n) for n in excludes]

        # now construct the `BatchAggregatorDict` instance
        return BatchAggregatorDict(
            agg_dict, excludes=excludes, default_factory=default_factory)

    def __init__(self,
                 aggregators: Mapping[str, BatchAggregator],
                 excludes: Sequence[str] = (),
                 default_factory: Optional[
                     Callable[[], BatchAggregator]] = None):
        """
        Construct a new :class:`BatchAggregatorDict`.

        Args:
            aggregators: The mapping from names to aggregators.
            excludes: The names to exclude from this dict.  If a name is
                excluded, no aggregator will be assigned to it,
                i.e., ``get(name)`` returns None, and ``__getitem__(name)``
                raises :class:`KeyError`.
            default_factory: The default factory, used to create a new
                :class:`BatchAggregator` instance when a requested name has
                no aggregator yet.  If not specified, ``__getitem__`` on a
                non-existing name will raise a :class:`KeyError`.
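
        For example, without a default factory, looking up an unknown name
        raises :class:`KeyError`:

        >>> d = BatchAggregatorDict({'loss': BatchAggregator('AVERAGE')})
        >>> d['acc']
        Traceback (most recent call last):
            ...
        KeyError: 'acc'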
        """
        self._aggregators = {}
        self._excludes = set(excludes or ())
        self._default_factory = default_factory

        for key in aggregators:
            if key not in self._excludes:
                agg = aggregators[key]
                if not isinstance(agg, BatchAggregator):
                    raise TypeError(f'Item {key!r} is not an instance of '
                                    f'{BatchAggregator.__qualname__}: '
                                    f'{agg!r}')
                self._aggregators[key] = agg

    def get(self, item: str, default: Any = None) -> Optional[BatchAggregator]:
        if item not in self._excludes:
            if item not in self._aggregators:
                if self._default_factory is not None:
                    self._aggregators[item] = self._default_factory()
                else:
                    return default
            return self._aggregators[item]

    def __getitem__(self, item: str) -> BatchAggregator:
        ret = self.get(item)
        if ret is None:
            raise KeyError(item)
        return ret

    def __len__(self) -> int:
        return len(self._aggregators)

    def __iter__(self) -> Iterator[str]:
        return iter(self._aggregators)