Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import operator
from enum import Enum
from functools import reduce
from typing import *
import numpy as np
from .stage import StageType
from .utils import ALL, NOT_SET
# Names exported by ``from <this module> import *``.
__all__ = [
    'BatchAggregationMode',
    'BatchAggregator',
    'BatchAggregatorDict',
]
class BatchAggregationMode(str, Enum):
    """The strategies a :class:`BatchAggregator` can use to combine batches."""

    #: Concatenate the batch arrays along the specified axis.
    CONCAT = 'CONCAT'

    #: Sum the batch arrays along the specified axis.
    SUM = 'SUM'

    #: Average the batch arrays along the specified axis.
    AVERAGE = 'AVERAGE'
class BatchAggregator(object):
    """
    Class to aggregate batch arrays.

    >>> agg = BatchAggregator(BatchAggregationMode.CONCAT)
    >>> agg
    BatchAggregator(mode=CONCAT, axis=0)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    array([1, 2, 3, 4, 5, 6])

    >>> agg = BatchAggregator(BatchAggregationMode.AVERAGE)
    >>> agg
    BatchAggregator(mode=AVERAGE, axis=None)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    3.5

    >>> agg = BatchAggregator(BatchAggregationMode.SUM)
    >>> agg
    BatchAggregator(mode=SUM, axis=None)
    >>> agg.add(np.array([1, 2, 3, 4]))
    >>> agg.add(np.array([5, 6]))
    >>> agg.get()
    21
    """

    mode: BatchAggregationMode
    # A single int, a tuple of ints, or `None` (reduce over all axes; only
    # possible for SUM / AVERAGE, since CONCAT requires an int).
    axis: Optional[Union[int, Tuple[int, ...]]]

    def __init__(self,
                 mode: Union[str, BatchAggregationMode],
                 axis: Optional[Union[int, Tuple[int, ...], List[int]]] = NOT_SET):
        """
        Construct a new :class:`BatchAggregator`.

        Args:
            mode: Aggregation mode.
            axis: The axis to aggregate. Defaults to `0` for `CONCAT` mode,
                while :obj:`None` for `SUM` and `AVERAGE` mode.  An iterable
                of axes is converted to a tuple of ints (a one-element
                iterable collapses to a plain int).

        Raises:
            TypeError: If `mode` is CONCAT and `axis` is not an int.
        """
        mode = BatchAggregationMode(mode)
        if axis is NOT_SET:
            axis = 0 if mode == BatchAggregationMode.CONCAT else None
        if mode == BatchAggregationMode.CONCAT:
            if not isinstance(axis, int):
                raise TypeError('`axis` must be a int when `mode` is CONCAT.')

        # normalize `axis` into either `None`, an int, or a tuple of ints
        if axis is not None:
            if hasattr(axis, '__iter__'):
                axis = tuple(int(v) for v in axis)
                if len(axis) == 1:
                    axis = axis[0]
            else:
                axis = int(axis)

        self.mode = mode
        self.axis = axis
        # CONCAT: list of collected batches; SUM: running sum;
        # AVERAGE: running weighted mean.  `None` until the first batch.
        self._buf = None
        # total weight seen so far (AVERAGE mode only)
        self._weight_sum = 0.

    def __repr__(self):
        return f'{self.__class__.__qualname__}' \
               f'(mode={self.mode.value}, axis={self.axis})'

    def get(self) -> Optional[np.ndarray]:
        """
        Get the aggregation result.

        Returns:
            The result, or :obj:`None` if no value has been collected.
            In CONCAT mode a fresh concatenated array is built on each call.
        """
        if self._buf is not None:
            if self.mode == BatchAggregationMode.CONCAT:
                return np.concatenate(self._buf, axis=self.axis)
            else:
                return self._buf

    def add(self,
            values: np.ndarray,
            weight: Optional[float] = 1.):
        """
        Add a batch array to the aggregator.

        Args:
            values: The batch array.
            weight: The batch weight, used only in `AVERAGE` mode.
                :obj:`None` is treated as `1.`.  In `AVERAGE` mode, a batch
                whose total weight (`weight` times the number of reduced
                elements) is zero -- e.g. an empty array -- is ignored.
        """
        # CONCAT: append the values to the buf
        if self.mode == BatchAggregationMode.CONCAT:
            if self._buf is None:
                self._buf = []
            self._buf.append(values)

        # SUM
        elif self.mode == BatchAggregationMode.SUM:
            batch_sum = np.sum(values, axis=self.axis)
            if self._buf is None:
                self._buf = batch_sum
            else:
                # Rebind instead of `+=`: an in-place update would mutate the
                # array a previous `get()` handed out, and would raise a
                # casting error when an int buffer meets a float batch.
                self._buf = self._buf + batch_sum

        # AVERAGE: maintain the `total_weight` state and update the buf
        else:
            if weight is None:
                weight = 1.

            # infer the batch size (number of elements being averaged)
            batch_shape = np.shape(values)
            if self.axis is None:
                batch_size = float(reduce(operator.mul, batch_shape, 1.))
            elif isinstance(self.axis, tuple):
                batch_size = float(reduce(
                    operator.mul, (batch_shape[a] for a in self.axis), 1.))
            else:
                batch_size = float(batch_shape[self.axis])
            batch_weight = weight * batch_size

            # A zero-weight batch (empty array or `weight == 0`) contributes
            # nothing to the mean; skipping it also avoids dividing by a zero
            # total weight when it is the first batch.
            if batch_weight == 0:
                return

            # do update the weight
            self._weight_sum += batch_weight
            r1 = weight / self._weight_sum
            batch_sum = np.sum(values, axis=self.axis)
            if self._buf is None:
                self._buf = r1 * batch_sum
            else:
                # running weighted mean:
                #   new = old + (weight * batch_sum - batch_weight * old) / W
                # rebinding (not `+=`) keeps earlier `get()` results intact
                r2 = batch_weight / self._weight_sum
                self._buf = self._buf + (r1 * batch_sum - r2 * self._buf)
class BatchAggregatorDict(Mapping[str, BatchAggregator]):
    """
    Maintain a dict of :class:`BatchAggregator` instances, maybe with
    a default factory to construct :class:`BatchAggregator` instance
    for new keys.

    >>> agg_dict = BatchAggregatorDict.new()
    >>> agg_dict['acc'].add(np.array([0.75, 0.875]))
    >>> agg_dict['loss'].add(np.array([0.125, 0.2]))
    >>> len(agg_dict)
    2
    >>> list(agg_dict)
    ['acc', 'loss']
    >>> agg_dict['acc'].get()
    0.8125
    >>> agg_dict['loss'].get()
    0.1625
    """

    @staticmethod
    def new(metrics: Union[Sequence[str], type(ALL)] = ALL,
            outputs: Union[Sequence[str], type(ALL)] = (),
            aggregators: Optional[Mapping[str, BatchAggregator]] = None,
            excludes: Sequence[str] = (),
            stage_type: Optional[StageType] = None) -> 'BatchAggregatorDict':
        """
        Construct a new :class:`BatchAggregatorDict` according to the field
        settings `metrics`, `outputs` and `aggregators`.

        Args:
            metrics: The names of the batch arrays, which should be aggregated
                by ``BatchAggregator('AVERAGE', axis=None)``.  :obj:`ALL`
                indicates that an array is by default a metric if it is neither
                specified in `outputs` nor in `aggregators`.
            outputs: The names of the batch arrays, which should be aggregated
                by ``BatchAggregator('CONCAT', axis=0)``.  :obj:`ALL`
                indicates that an array is by default an output if it is neither
                specified in `metrics` nor in `aggregators`.
            aggregators: The dict of names and their corresponding aggregators.
            excludes: The names to exclude.  If a name is excluded, no
                aggregator will be designated to this name, i.e., ``get(name)``
                returns the default value, and ``__getitem__(name)`` raises
                `KeyError`.
            stage_type: If specified, will add stage metric prefix to the keys
                of `metrics`, `outputs`, `aggregators` and `excludes`.

        Returns:
            The aggregator dict.

        Raises:
            ValueError: If both `metrics` and `outputs` are :obj:`ALL`.

        Notes:
            :obj:`ALL` could be specified to at most one of `metrics`
            and `outputs`.  The argument `aggregators` has higher priority
            than `outputs`, and so does `outputs` have higher priority than
            `metrics`.  That is to say, if a name is specified in both
            `aggregators` and `outputs`, then the aggregator specified in
            `aggregators` will be chosen; this is also true if a name is
            specified in both `outputs` and `metrics`.
        """
        def with_prefix(key: str) -> str:
            # apply the stage metric prefix, when a stage type is given
            if stage_type is not None:
                key = stage_type.add_metric_prefix(key)
            return key

        # the aggregator factories
        def average_aggregator_factory() -> BatchAggregator:
            return BatchAggregator(mode=BatchAggregationMode.AVERAGE, axis=None)

        def concat_aggregator_factory() -> BatchAggregator:
            return BatchAggregator(mode=BatchAggregationMode.CONCAT, axis=0)

        # determine the default factory
        if metrics == ALL and outputs == ALL:
            raise ValueError('Only one of `metrics` and `outputs` can be '
                             '`ALL`.')
        elif metrics == ALL:
            default_factory = average_aggregator_factory
        elif outputs == ALL:
            default_factory = concat_aggregator_factory
        else:
            default_factory = None

        # build the aggregator instances; later assignments win, which
        # implements the priority `aggregators` > `outputs` > `metrics`
        agg_dict = {}
        if metrics != ALL and metrics:
            for key in metrics:
                agg_dict[with_prefix(key)] = average_aggregator_factory()
        if outputs != ALL and outputs:
            for key in outputs:
                agg_dict[with_prefix(key)] = concat_aggregator_factory()
        if aggregators:
            for key, agg in aggregators.items():
                agg_dict[with_prefix(key)] = agg

        # build the excludes names
        if excludes and stage_type is not None:
            excludes = [with_prefix(n) for n in excludes]

        # now construct the `BatchAggregatorDict` instance
        return BatchAggregatorDict(
            agg_dict, excludes=excludes, default_factory=default_factory)

    def __init__(self,
                 aggregators: Mapping[str, BatchAggregator],
                 excludes: Sequence[str] = (),
                 default_factory: Optional[
                     Callable[[], BatchAggregator]] = None):
        """
        Construct a new :class:`BatchAggregatorDict`.

        Args:
            aggregators: The mapping from names to aggregators.
            excludes: The names to exclude from this dict.  If a name is
                excluded, no aggregator will be designated to this name,
                i.e., ``get(name)`` returns the default value, and
                ``__getitem__(name)`` raises :class:`KeyError`.
            default_factory: The default factory, which is used to create
                new :class:`BatchAggregator` instances if the aggregator
                to a requested name does not exist.  If not specified,
                accessing non-existing name will raise an error.

        Raises:
            TypeError: If a non-excluded item of `aggregators` is not a
                :class:`BatchAggregator`.
        """
        self._aggregators = {}
        self._excludes = set(excludes or ())
        self._default_factory = default_factory

        for key in aggregators:
            if key not in self._excludes:
                agg = aggregators[key]
                if not isinstance(agg, BatchAggregator):
                    raise TypeError(f'Item {key!r} is not an instance of '
                                    f'{BatchAggregator.__qualname__}: '
                                    f'{agg!r}')
                self._aggregators[key] = agg

    def get(self, item: str, default: Any = None) -> Optional[BatchAggregator]:
        """
        Get the aggregator for `item`, creating it via the default factory
        (and storing it in this dict) if it does not exist yet.

        Args:
            item: The name of the aggregator.
            default: Returned when `item` is excluded, or when it does not
                exist and there is no default factory.

        Returns:
            The aggregator, or `default`.
        """
        if item in self._excludes:
            return default
        if item not in self._aggregators:
            if self._default_factory is None:
                return default
            self._aggregators[item] = self._default_factory()
        return self._aggregators[item]

    def __getitem__(self, item: str) -> BatchAggregator:
        # NOTE: the inherited `Mapping.__contains__` delegates to this method,
        # so with a default factory `name in self` creates the aggregator.
        ret = self.get(item)
        if ret is None:
            raise KeyError(item)
        return ret

    def __len__(self) -> int:
        # counts only the aggregators created so far
        return len(self._aggregators)

    def __iter__(self) -> Iterator[str]:
        # iterates names in insertion order of the underlying dict
        return iter(self._aggregators)